diff --git a/machine-learning/src/main.py b/machine-learning/src/main.py index a656860475..cd6726f4d0 100644 --- a/machine-learning/src/main.py +++ b/machine-learning/src/main.py @@ -1,43 +1,58 @@ import os from flask import Flask, request from transformers import pipeline +from sentence_transformers import SentenceTransformer, util +from PIL import Image +is_dev = os.getenv('NODE_ENV') == 'development' +server_port = os.getenv('MACHINE_LEARNING_PORT', 3003) +server_host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0') + +classification_model = os.getenv('MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50') +object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny') +clip_image_model = os.getenv('MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32') +clip_text_model = os.getenv('MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32') + +_model_cache = {} +def _get_model(model, task=None): + global _model_cache + key = '|'.join([model, str(task)]) + if key not in _model_cache: + if task: + _model_cache[key] = pipeline(model=model, task=task) + else: + _model_cache[key] = SentenceTransformer(model) + return _model_cache[key] server = Flask(__name__) - -classifier = pipeline( - task="image-classification", - model="microsoft/resnet-50" -) - -detector = pipeline( - task="object-detection", - model="hustvl/yolos-tiny" -) - - -# Environment resolver -is_dev = os.getenv('NODE_ENV') == 'development' -server_port = os.getenv('MACHINE_LEARNING_PORT') or 3003 - - @server.route("/ping") def ping(): return "pong" - @server.route("/object-detection/detect-object", methods=['POST']) def object_detection(): + model = _get_model(object_model, 'object-detection') assetPath = request.json['thumbnailPath'] - return run_engine(detector, assetPath), 201 - + return run_engine(model, assetPath), 200 @server.route("/image-classifier/tag-image", methods=['POST']) def image_classification(): + model = _get_model(classification_model, 'image-classification') assetPath = request.json['thumbnailPath'] - return run_engine(classifier, assetPath), 201 + return run_engine(model, assetPath), 200 +@server.route("/sentence-transformer/encode-image", methods=['POST']) +def clip_encode_image(): + model = _get_model(clip_image_model) + assetPath = request.json['thumbnailPath'] + return model.encode(Image.open(assetPath)).tolist(), 200 + +@server.route("/sentence-transformer/encode-text", methods=['POST']) +def clip_encode_text(): + model = _get_model(clip_text_model) + text = request.json['text'] + return model.encode(text).tolist(), 200 def run_engine(engine, path): result = [] @@ -55,4 +70,4 @@ def run_engine(engine, path): if __name__ == "__main__": - server.run(debug=is_dev, host='0.0.0.0', port=server_port) + server.run(debug=is_dev, host=server_host, port=server_port) diff --git a/mobile/openapi/doc/SearchApi.md b/mobile/openapi/doc/SearchApi.md index 8faafbfabe..2a8a5733b6 100644 Binary files a/mobile/openapi/doc/SearchApi.md and b/mobile/openapi/doc/SearchApi.md differ diff --git a/mobile/openapi/lib/api/search_api.dart b/mobile/openapi/lib/api/search_api.dart index 6e7560b311..c50dc04890 100644 Binary files a/mobile/openapi/lib/api/search_api.dart and b/mobile/openapi/lib/api/search_api.dart differ diff --git a/mobile/openapi/test/search_api_test.dart b/mobile/openapi/test/search_api_test.dart index 8136969c91..ba9adaec59 100644 Binary files a/mobile/openapi/test/search_api_test.dart and b/mobile/openapi/test/search_api_test.dart differ diff --git a/server/apps/immich/src/api-v1/album/album.service.spec.ts b/server/apps/immich/src/api-v1/album/album.service.spec.ts index b347d85690..699c5e8db8 100644 --- a/server/apps/immich/src/api-v1/album/album.service.spec.ts +++ b/server/apps/immich/src/api-v1/album/album.service.spec.ts @@ -163,7 +163,7 @@ describe('Album service', () => { expect(result.id).toEqual(albumEntity.id); expect(result.albumName).toEqual(albumEntity.albumName); - expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: albumEntity } }); + expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [albumEntity.id] } }); }); it('gets list of albums for auth user', async () => { @@ -316,7 +316,7 @@ describe('Album service', () => { albumName: updatedAlbumName, albumThumbnailAssetId: updatedAlbumThumbnailAssetId, }); - expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: updatedAlbum } }); + expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [updatedAlbum.id] } }); }); it('prevents updating a not owned album (shared with auth user)', async () => { diff --git a/server/apps/immich/src/api-v1/album/album.service.ts b/server/apps/immich/src/api-v1/album/album.service.ts index 999b198985..f73aaf3061 100644 --- a/server/apps/immich/src/api-v1/album/album.service.ts +++ b/server/apps/immich/src/api-v1/album/album.service.ts @@ -59,7 +59,7 @@ export class AlbumService { async create(authUser: AuthUserDto, createAlbumDto: CreateAlbumDto): Promise<AlbumResponseDto> { const albumEntity = await this.albumRepository.create(authUser.id, createAlbumDto); - await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: albumEntity } }); + await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [albumEntity.id] } }); return mapAlbum(albumEntity); } @@ -107,7 +107,7 @@ export class AlbumService { } await this.albumRepository.delete(album); - await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ALBUM, data: { id: albumId } }); + await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ALBUM, data: { ids: [albumId] } }); } async removeUserFromAlbum(authUser: AuthUserDto, albumId: string, userId: string | 'me'): Promise<void> { @@ -171,7 +171,7 @@ export class AlbumService { const updatedAlbum = await this.albumRepository.updateAlbum(album, updateAlbumDto); - await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: updatedAlbum } }); + await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [updatedAlbum.id] } }); return mapAlbum(updatedAlbum); } diff --git a/server/apps/immich/src/api-v1/asset/asset.service.spec.ts b/server/apps/immich/src/api-v1/asset/asset.service.spec.ts index 6d2e5f5abb..a4875a95aa 100644 --- a/server/apps/immich/src/api-v1/asset/asset.service.spec.ts +++ b/server/apps/immich/src/api-v1/asset/asset.service.spec.ts @@ -455,8 +455,8 @@ describe('AssetService', () => { ]); expect(jobMock.queue.mock.calls).toEqual([ - [{ name: JobName.SEARCH_REMOVE_ASSET, data: { id: 'asset1' } }], - [{ name: JobName.SEARCH_REMOVE_ASSET, data: { id: 'asset2' } }], + [{ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: ['asset1'] } }], + [{ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: ['asset2'] } }], [ { name: JobName.DELETE_FILES, diff --git a/server/apps/immich/src/api-v1/asset/asset.service.ts b/server/apps/immich/src/api-v1/asset/asset.service.ts index f3070637d6..3f051fb223 100644 --- a/server/apps/immich/src/api-v1/asset/asset.service.ts +++ b/server/apps/immich/src/api-v1/asset/asset.service.ts @@ -170,7 +170,7 @@ export class AssetService { const updatedAsset = await this._assetRepository.update(authUser.id, asset, dto); - await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { asset: updatedAsset } }); + await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [assetId] } }); return mapAsset(updatedAsset); } @@ -251,8 +251,8 @@ export class AssetService { res.header('Cache-Control', 'none'); Logger.error(`Cannot create read stream for asset ${asset.id}`, 'getAssetThumbnail'); throw new InternalServerErrorException( - e, `Cannot read thumbnail file for asset ${asset.id} - contact your administrator`, + { cause: e as Error }, ); } } @@ -427,7 +427,7 @@ export class AssetService { try { await this._assetRepository.remove(asset); - await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ASSET, data: { id } }); + await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: [id] } }); result.push({ id, status: DeleteAssetStatusEnum.SUCCESS }); deleteQueue.push(asset.originalPath, asset.webpPath, asset.resizePath, asset.encodedVideoPath); diff --git a/server/apps/immich/src/api-v1/job/job.service.ts b/server/apps/immich/src/api-v1/job/job.service.ts index f41997ff1a..3a630130c3 100644 --- a/server/apps/immich/src/api-v1/job/job.service.ts +++ b/server/apps/immich/src/api-v1/job/job.service.ts @@ -70,6 +70,7 @@ export class JobService { for (const asset of assets) { await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } }); await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } }); + await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } }); } return assets.length; } diff --git a/server/apps/immich/src/controllers/search.controller.ts b/server/apps/immich/src/controllers/search.controller.ts index 2c2248c3fc..d172c7592b 100644 --- a/server/apps/immich/src/controllers/search.controller.ts +++ b/server/apps/immich/src/controllers/search.controller.ts @@ -20,7 +20,7 @@ export class SearchController { @Get() async search( @GetAuthUser() authUser: AuthUserDto, - @Query(new ValidationPipe({ transform: true })) dto: SearchDto, + @Query(new ValidationPipe({ transform: true })) dto: SearchDto | any, ): Promise<SearchResponseDto> { return this.searchService.search(authUser, dto); } diff --git a/server/apps/microservices/src/processors.ts b/server/apps/microservices/src/processors.ts index 63f05aed12..fe935744d7 100644 --- a/server/apps/microservices/src/processors.ts +++ b/server/apps/microservices/src/processors.ts @@ -1,10 +1,9 @@ import { AssetService, - IAlbumJob, IAssetJob, IAssetUploadedJob, + IBulkEntityJob, IDeleteFilesJob, - IDeleteJob, IUserDeletionJob, JobName, MediaService, @@ -53,15 +52,20 @@ export class BackgroundTaskProcessor { export class MachineLearningProcessor { constructor(private smartInfoService: SmartInfoService) {} - @Process({ name: JobName.IMAGE_TAGGING, concurrency: 2 }) + @Process({ name: JobName.IMAGE_TAGGING, concurrency: 1 }) async onTagImage(job: Job<IAssetJob>) { await this.smartInfoService.handleTagImage(job.data); } - @Process({ name: JobName.OBJECT_DETECTION, concurrency: 2 }) + @Process({ name: JobName.OBJECT_DETECTION, concurrency: 1 }) async onDetectObject(job: Job<IAssetJob>) { await this.smartInfoService.handleDetectObjects(job.data); } + + @Process({ name: JobName.ENCODE_CLIP, concurrency: 1 }) + async onEncodeClip(job: Job<IAssetJob>) { + await this.smartInfoService.handleEncodeClip(job.data); + } } @Processor(QueueName.SEARCH) @@ -79,23 +83,23 @@ export class SearchIndexProcessor { } @Process(JobName.SEARCH_INDEX_ALBUM) - async onIndexAlbum(job: Job<IAlbumJob>) { - await this.searchService.handleIndexAlbum(job.data); + onIndexAlbum(job: Job<IBulkEntityJob>) { + this.searchService.handleIndexAlbum(job.data); } @Process(JobName.SEARCH_INDEX_ASSET) - async onIndexAsset(job: Job<IAssetJob>) { - await this.searchService.handleIndexAsset(job.data); + onIndexAsset(job: Job<IBulkEntityJob>) { + this.searchService.handleIndexAsset(job.data); } @Process(JobName.SEARCH_REMOVE_ALBUM) - async onRemoveAlbum(job: Job<IDeleteJob>) { - await this.searchService.handleRemoveAlbum(job.data); + onRemoveAlbum(job: Job<IBulkEntityJob>) { + this.searchService.handleRemoveAlbum(job.data); } @Process(JobName.SEARCH_REMOVE_ASSET) - async onRemoveAsset(job: Job<IDeleteJob>) { - await this.searchService.handleRemoveAsset(job.data); + onRemoveAsset(job: Job<IBulkEntityJob>) { + this.searchService.handleRemoveAsset(job.data); } } diff --git a/server/immich-openapi-specs.json b/server/immich-openapi-specs.json index 2c21d6214a..26653e5daf 100644 --- a/server/immich-openapi-specs.json +++ b/server/immich-openapi-specs.json @@ -548,116 +548,7 @@ "get": { "operationId": "search", "description": "", - "parameters": [ - { - "name": "query", - "required": false, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "type", - "required": false, - "in": "query", - "schema": { - "enum": [ - "IMAGE", - "VIDEO", - "AUDIO", - "OTHER" - ], - "type": "string" - } - }, - { - "name": "isFavorite", - "required": false, - "in": "query", - "schema": { - "type": "boolean" - } - }, - { - "name": "exifInfo.city", - "required": false, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "exifInfo.state", - "required": false, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "exifInfo.country", - "required": false, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "exifInfo.make", - "required": false, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "exifInfo.model", - "required": false, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "smartInfo.objects", - "required": false, - "in": "query", - "schema": { - "type": "array", - "items": { - "type": "string" - } - } - }, - { - "name": "smartInfo.tags", - "required": false, - "in": "query", - "schema": { - "type": "array", - "items": { - "type": "string" - } - } - }, - { - "name": "recent", - "required": false, - "in": "query", - "schema": { - "type": "boolean" - } - }, - { - "name": "motion", - "required": false, - "in": "query", - "schema": { - "type": "boolean" - } - } - ], + "parameters": [], "responses": { "200": { "description": "", diff --git a/server/libs/domain/src/album/album.repository.ts b/server/libs/domain/src/album/album.repository.ts index adc62ea971..424b901776 100644 --- a/server/libs/domain/src/album/album.repository.ts +++ b/server/libs/domain/src/album/album.repository.ts @@ -3,6 +3,7 @@ import { AlbumEntity } from '@app/infra/db/entities'; export const IAlbumRepository = 'IAlbumRepository'; export interface IAlbumRepository { + getByIds(ids: string[]): Promise<AlbumEntity[]>; deleteAll(userId: string): Promise<void>; getAll(): Promise<AlbumEntity[]>; save(album: Partial<AlbumEntity>): Promise<AlbumEntity>; diff --git a/server/libs/domain/src/asset/asset.core.ts b/server/libs/domain/src/asset/asset.core.ts index 46b4231ff4..164c373809 100644 --- a/server/libs/domain/src/asset/asset.core.ts +++ b/server/libs/domain/src/asset/asset.core.ts @@ -11,7 +11,10 @@ export class AssetCore { async save(asset: Partial<AssetEntity>) { const _asset = await this.assetRepository.save(asset); - await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { asset: _asset } }); + await this.jobRepository.queue({ + name: JobName.SEARCH_INDEX_ASSET, + data: { ids: [_asset.id] }, + }); return _asset; } diff --git a/server/libs/domain/src/asset/asset.repository.ts b/server/libs/domain/src/asset/asset.repository.ts index 0173cc4ee1..9b75eb8148 100644 --- a/server/libs/domain/src/asset/asset.repository.ts +++ b/server/libs/domain/src/asset/asset.repository.ts @@ -7,6 +7,7 @@ export interface AssetSearchOptions { export const IAssetRepository = 'IAssetRepository'; export interface IAssetRepository { + getByIds(ids: string[]): Promise<AssetEntity[]>; deleteAll(ownerId: string): Promise<void>; getAll(options?: AssetSearchOptions): Promise<AssetEntity[]>; save(asset: Partial<AssetEntity>): Promise<AssetEntity>; diff --git a/server/libs/domain/src/asset/asset.service.spec.ts b/server/libs/domain/src/asset/asset.service.spec.ts index 536a0c148c..608ce4985a 100644 --- a/server/libs/domain/src/asset/asset.service.spec.ts +++ b/server/libs/domain/src/asset/asset.service.spec.ts @@ -54,7 +54,7 @@ describe(AssetService.name, () => { expect(assetMock.save).toHaveBeenCalledWith(assetEntityStub.image); expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ASSET, - data: { asset: assetEntityStub.image }, + data: { ids: [assetEntityStub.image.id] }, }); }); }); diff --git a/server/libs/domain/src/job/job.constants.ts b/server/libs/domain/src/job/job.constants.ts index 52ee425720..0404f33ddd 100644 --- a/server/libs/domain/src/job/job.constants.ts +++ b/server/libs/domain/src/job/job.constants.ts @@ -29,4 +29,5 @@ export enum JobName { SEARCH_INDEX_ALBUM = 'search-index-album', SEARCH_REMOVE_ALBUM = 'search-remove-album', SEARCH_REMOVE_ASSET = 'search-remove-asset', + ENCODE_CLIP = 'clip-encode', } diff --git a/server/libs/domain/src/job/job.interface.ts b/server/libs/domain/src/job/job.interface.ts index 0810bdad07..ad21fb1484 100644 --- a/server/libs/domain/src/job/job.interface.ts +++ b/server/libs/domain/src/job/job.interface.ts @@ -8,15 +8,15 @@ export interface IAssetJob { asset: AssetEntity; } +export interface IBulkEntityJob { + ids: string[]; +} + export interface IAssetUploadedJob { asset: AssetEntity; fileName: string; } -export interface IDeleteJob { - id: string; -} - export interface IDeleteFilesJob { files: Array<string | null | undefined>; } diff --git a/server/libs/domain/src/job/job.repository.ts b/server/libs/domain/src/job/job.repository.ts index 0867f5391c..d9f72586c8 100644 --- a/server/libs/domain/src/job/job.repository.ts +++ b/server/libs/domain/src/job/job.repository.ts @@ -1,10 +1,9 @@ import { JobName, QueueName } from './job.constants'; import { - IAlbumJob, IAssetJob, IAssetUploadedJob, + IBulkEntityJob, IDeleteFilesJob, - IDeleteJob, IReverseGeocodingJob, IUserDeletionJob, } from './job.interface'; @@ -31,13 +30,14 @@ export type JobItem = | { name: JobName.EXTRACT_VIDEO_METADATA; data: IAssetUploadedJob } | { name: JobName.OBJECT_DETECTION; data: IAssetJob } | { name: JobName.IMAGE_TAGGING; data: IAssetJob } + | { name: JobName.ENCODE_CLIP; data: IAssetJob } | { name: JobName.DELETE_FILES; data: IDeleteFilesJob } | { name: JobName.SEARCH_INDEX_ASSETS } - | { name: JobName.SEARCH_INDEX_ASSET; data: IAssetJob } + | { name: JobName.SEARCH_INDEX_ASSET; data: IBulkEntityJob } | { name: JobName.SEARCH_INDEX_ALBUMS } - | { name: JobName.SEARCH_INDEX_ALBUM; data: IAlbumJob } - | { name: JobName.SEARCH_REMOVE_ASSET; data: IDeleteJob } - | { name: JobName.SEARCH_REMOVE_ALBUM; data: IDeleteJob }; + | { name: JobName.SEARCH_INDEX_ALBUM; data: IBulkEntityJob } + | { name: JobName.SEARCH_REMOVE_ASSET; data: IBulkEntityJob } + | { name: JobName.SEARCH_REMOVE_ALBUM; data: IBulkEntityJob }; export const IJobRepository = 'IJobRepository'; diff --git a/server/libs/domain/src/media/media.service.ts b/server/libs/domain/src/media/media.service.ts index 32ec126dec..97f84162cf 100644 --- a/server/libs/domain/src/media/media.service.ts +++ b/server/libs/domain/src/media/media.service.ts @@ -54,6 +54,7 @@ export class MediaService { await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } }); await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } }); await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } }); + await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } }); this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset)); } @@ -72,6 +73,7 @@ export class MediaService { await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } }); await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } }); await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } }); + await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } }); this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset)); } catch (error: any) { diff --git a/server/libs/domain/src/search/dto/search.dto.ts b/server/libs/domain/src/search/dto/search.dto.ts index 1610e2e713..83fbcb5de7 100644 --- a/server/libs/domain/src/search/dto/search.dto.ts +++ b/server/libs/domain/src/search/dto/search.dto.ts @@ -4,11 +4,21 @@ import { IsArray, IsBoolean, IsEnum, IsNotEmpty, IsOptional, IsString } from 'cl import { toBoolean } from '../../../../../apps/immich/src/utils/transform.util'; export class SearchDto { + @IsString() + @IsNotEmpty() + @IsOptional() + q?: string; + @IsString() @IsNotEmpty() @IsOptional() query?: string; + @IsBoolean() + @IsOptional() + @Transform(toBoolean) + clip?: boolean; + @IsEnum(AssetType) @IsOptional() type?: AssetType; diff --git a/server/libs/domain/src/search/search.repository.ts b/server/libs/domain/src/search/search.repository.ts index 4508b14514..8db6eb14b4 100644 --- a/server/libs/domain/src/search/search.repository.ts +++ b/server/libs/domain/src/search/search.repository.ts @@ -5,6 +5,11 @@ export enum SearchCollection { ALBUMS = 'albums', } +export enum SearchStrategy { + CLIP = 'CLIP', + TEXT = 'TEXT', +} + export interface SearchFilter { id?: string; userId: string; @@ -19,6 +24,7 @@ export interface SearchFilter { tags?: string[]; recent?: boolean; motion?: boolean; + debug?: boolean; } export interface SearchResult<T> { @@ -57,16 +63,15 @@ export interface ISearchRepository { setup(): Promise<void>; checkMigrationStatus(): Promise<SearchCollectionIndexStatus>; - index(collection: SearchCollection.ASSETS, item: AssetEntity): Promise<void>; - index(collection: SearchCollection.ALBUMS, item: AlbumEntity): Promise<void>; + importAlbums(items: AlbumEntity[], done: boolean): Promise<void>; + importAssets(items: AssetEntity[], done: boolean): Promise<void>; - delete(collection: SearchCollection, id: string): Promise<void>; + deleteAlbums(ids: string[]): Promise<void>; + deleteAssets(ids: string[]): Promise<void>; - import(collection: SearchCollection.ASSETS, items: AssetEntity[], done: boolean): Promise<void>; - import(collection: SearchCollection.ALBUMS, items: AlbumEntity[], done: boolean): Promise<void>; - - search(collection: SearchCollection.ASSETS, query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>; - search(collection: SearchCollection.ALBUMS, query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>; + searchAlbums(query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>; + searchAssets(query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>; + vectorSearch(query: number[], filters: SearchFilter): Promise<SearchResult<AssetEntity>>; explore(userId: string): Promise<SearchExploreItem<AssetEntity>[]>; } diff --git a/server/libs/domain/src/search/search.service.spec.ts b/server/libs/domain/src/search/search.service.spec.ts index 813091f8d8..ff27ffdec8 100644 --- a/server/libs/domain/src/search/search.service.spec.ts +++ b/server/libs/domain/src/search/search.service.spec.ts @@ -4,25 +4,32 @@ import { plainToInstance } from 'class-transformer'; import { albumStub, assetEntityStub, + asyncTick, authStub, newAlbumRepositoryMock, newAssetRepositoryMock, newJobRepositoryMock, + newMachineLearningRepositoryMock, newSearchRepositoryMock, + searchStub, } from '../../test'; import { IAlbumRepository } from '../album/album.repository'; import { IAssetRepository } from '../asset/asset.repository'; import { JobName } from '../job'; import { IJobRepository } from '../job/job.repository'; +import { IMachineLearningRepository } from '../smart-info'; import { SearchDto } from './dto'; import { ISearchRepository } from './search.repository'; import { SearchService } from './search.service'; +jest.useFakeTimers(); + describe(SearchService.name, () => { let sut: SearchService; let albumMock: jest.Mocked<IAlbumRepository>; let assetMock: jest.Mocked<IAssetRepository>; let jobMock: jest.Mocked<IJobRepository>; + let machineMock: jest.Mocked<IMachineLearningRepository>; let searchMock: jest.Mocked<ISearchRepository>; let configMock: jest.Mocked<ConfigService>; @@ -30,10 +37,15 @@ describe(SearchService.name, () => { albumMock = newAlbumRepositoryMock(); assetMock = newAssetRepositoryMock(); jobMock = newJobRepositoryMock(); + machineMock = newMachineLearningRepositoryMock(); searchMock = newSearchRepositoryMock(); configMock = { get: jest.fn() } as unknown as jest.Mocked<ConfigService>; - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); + sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); + }); + + afterEach(() => { + sut.teardown(); }); it('should work', () => { @@ -69,7 +81,7 @@ describe(SearchService.name, () => { it('should be disabled via an env variable', () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); expect(sut.isEnabled()).toBe(false); }); @@ -82,7 +94,7 @@ describe(SearchService.name, () => { it('should return the config when search is disabled', () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); expect(sut.getConfig()).toEqual({ enabled: false }); }); @@ -91,13 +103,15 @@ describe(SearchService.name, () => { describe(`bootstrap`, () => { it('should skip when search is disabled', async () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); await sut.bootstrap(); expect(searchMock.setup).not.toHaveBeenCalled(); expect(searchMock.checkMigrationStatus).not.toHaveBeenCalled(); expect(jobMock.queue).not.toHaveBeenCalled(); + + sut.teardown(); }); it('should skip schema migration if not needed', async () => { @@ -123,21 +137,18 @@ describe(SearchService.name, () => { describe('search', () => { it('should throw an error is search is disabled', async () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); await expect(sut.search(authStub.admin, {})).rejects.toBeInstanceOf(BadRequestException); - expect(searchMock.search).not.toHaveBeenCalled(); + expect(searchMock.searchAlbums).not.toHaveBeenCalled(); + expect(searchMock.searchAssets).not.toHaveBeenCalled(); }); it('should search assets and albums', async () => { - searchMock.search.mockResolvedValue({ - total: 0, - count: 0, - page: 1, - items: [], - facets: [], - }); + searchMock.searchAssets.mockResolvedValue(searchStub.emptyResults); + searchMock.searchAlbums.mockResolvedValue(searchStub.emptyResults); + searchMock.vectorSearch.mockResolvedValue(searchStub.emptyResults); await expect(sut.search(authStub.admin, {})).resolves.toEqual({ albums: { @@ -156,162 +167,158 @@ describe(SearchService.name, () => { }, }); - expect(searchMock.search.mock.calls).toEqual([ - ['assets', '*', { userId: authStub.admin.id }], - ['albums', '*', { userId: authStub.admin.id }], - ]); + // expect(searchMock.searchAssets).toHaveBeenCalledWith('*', { userId: authStub.admin.id }); + expect(searchMock.searchAlbums).toHaveBeenCalledWith('*', { userId: authStub.admin.id }); }); }); describe('handleIndexAssets', () => { - it('should skip if search is disabled', async () => { - configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); - - await sut.handleIndexAssets(); - - expect(searchMock.import).not.toHaveBeenCalled(); - }); - it('should index all the assets', async () => { - assetMock.getAll.mockResolvedValue([]); + assetMock.getAll.mockResolvedValue([assetEntityStub.image]); await sut.handleIndexAssets(); - expect(searchMock.import).toHaveBeenCalledWith('assets', [], true); + expect(searchMock.importAssets).toHaveBeenCalledWith([assetEntityStub.image], true); }); it('should log an error', async () => { - assetMock.getAll.mockResolvedValue([]); - searchMock.import.mockRejectedValue(new Error('import failed')); + assetMock.getAll.mockResolvedValue([assetEntityStub.image]); + searchMock.importAssets.mockRejectedValue(new Error('import failed')); await sut.handleIndexAssets(); + + expect(searchMock.importAssets).toHaveBeenCalled(); + }); + + it('should skip if search is disabled', async () => { + configMock.get.mockReturnValue('false'); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); + + await sut.handleIndexAssets(); + + expect(searchMock.importAssets).not.toHaveBeenCalled(); + expect(searchMock.importAlbums).not.toHaveBeenCalled(); }); }); describe('handleIndexAsset', () => { - it('should skip if search is disabled', async () => { + it('should skip if search is disabled', () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); - - await sut.handleIndexAsset({ asset: assetEntityStub.image }); - - expect(searchMock.index).not.toHaveBeenCalled(); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); + sut.handleIndexAsset({ ids: [assetEntityStub.image.id] }); }); - it('should index the asset', async () => { - await sut.handleIndexAsset({ asset: assetEntityStub.image }); - - expect(searchMock.index).toHaveBeenCalledWith('assets', assetEntityStub.image); - }); - - it('should log an error', async () => { - searchMock.index.mockRejectedValue(new Error('index failed')); - - await sut.handleIndexAsset({ asset: assetEntityStub.image }); - - expect(searchMock.index).toHaveBeenCalled(); + it('should index the asset', () => { + sut.handleIndexAsset({ ids: [assetEntityStub.image.id] }); }); }); describe('handleIndexAlbums', () => { - it('should skip if search is disabled', async () => { + it('should skip if search is disabled', () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); - - await sut.handleIndexAlbums(); - - expect(searchMock.import).not.toHaveBeenCalled(); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); + sut.handleIndexAlbums(); }); it('should index all the albums', async () => { - albumMock.getAll.mockResolvedValue([]); + albumMock.getAll.mockResolvedValue([albumStub.empty]); await sut.handleIndexAlbums(); - expect(searchMock.import).toHaveBeenCalledWith('albums', [], true); + expect(searchMock.importAlbums).toHaveBeenCalledWith([albumStub.empty], true); }); it('should log an error', async () => { - albumMock.getAll.mockResolvedValue([]); - searchMock.import.mockRejectedValue(new Error('import failed')); + albumMock.getAll.mockResolvedValue([albumStub.empty]); + searchMock.importAlbums.mockRejectedValue(new Error('import failed')); await sut.handleIndexAlbums(); + + expect(searchMock.importAlbums).toHaveBeenCalled(); }); }); describe('handleIndexAlbum', () => { - it('should skip if search is disabled', async () => { + it('should skip if search is disabled', () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); - - await sut.handleIndexAlbum({ album: albumStub.empty }); - - expect(searchMock.index).not.toHaveBeenCalled(); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); + sut.handleIndexAlbum({ ids: [albumStub.empty.id] }); }); - it('should index the album', async () => { - await sut.handleIndexAlbum({ album: albumStub.empty }); - - expect(searchMock.index).toHaveBeenCalledWith('albums', albumStub.empty); - }); - - it('should log an error', async () => { - searchMock.index.mockRejectedValue(new Error('index failed')); - - await sut.handleIndexAlbum({ album: albumStub.empty }); - - expect(searchMock.index).toHaveBeenCalled(); + it('should index the album', () => { + sut.handleIndexAlbum({ ids: [albumStub.empty.id] }); }); }); describe('handleRemoveAlbum', () => { - it('should skip if search is disabled', async () => { + it('should skip if search is disabled', () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); - - await sut.handleRemoveAlbum({ id: 'album1' }); - - expect(searchMock.delete).not.toHaveBeenCalled(); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); + sut.handleRemoveAlbum({ ids: ['album1'] }); }); - it('should remove the album', async () => { - await sut.handleRemoveAlbum({ id: 'album1' }); - - expect(searchMock.delete).toHaveBeenCalledWith('albums', 'album1'); - }); - - it('should log an error', async () => { - searchMock.delete.mockRejectedValue(new Error('remove failed')); - - await sut.handleRemoveAlbum({ id: 'album1' }); - - expect(searchMock.delete).toHaveBeenCalled(); + it('should remove the album', () => { + sut.handleRemoveAlbum({ ids: ['album1'] }); }); }); describe('handleRemoveAsset', () => { - it('should skip if search is disabled', async () => { + it('should skip if search is disabled', () => { configMock.get.mockReturnValue('false'); - sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); - - await sut.handleRemoveAsset({ id: 'asset1`' }); - - expect(searchMock.delete).not.toHaveBeenCalled(); + const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock); + sut.handleRemoveAsset({ ids: ['asset1'] }); }); - it('should remove the asset', async () => { - await sut.handleRemoveAsset({ id: 'asset1' }); + it('should remove the asset', () => { + sut.handleRemoveAsset({ ids: ['asset1'] }); + }); + }); - expect(searchMock.delete).toHaveBeenCalledWith('assets', 'asset1'); + describe('flush', () => { + it('should flush queued album updates', async () => { + albumMock.getByIds.mockResolvedValue([albumStub.empty]); + + sut.handleIndexAlbum({ ids: ['album1'] }); + + jest.runOnlyPendingTimers(); + + await asyncTick(4); + + expect(albumMock.getByIds).toHaveBeenCalledWith(['album1']); + expect(searchMock.importAlbums).toHaveBeenCalledWith([albumStub.empty], false); }); - it('should log an error', async () => { - searchMock.delete.mockRejectedValue(new Error('remove failed')); + it('should flush queued album deletes', async () => { + sut.handleRemoveAlbum({ ids: ['album1'] }); - await sut.handleRemoveAsset({ id: 'asset1' }); + jest.runOnlyPendingTimers(); - expect(searchMock.delete).toHaveBeenCalled(); + await asyncTick(4); + + expect(searchMock.deleteAlbums).toHaveBeenCalledWith(['album1']); + }); + + it('should flush queued asset updates', async () => { + assetMock.getByIds.mockResolvedValue([assetEntityStub.image]); + + sut.handleIndexAsset({ ids: ['asset1'] }); + + jest.runOnlyPendingTimers(); + + await asyncTick(4); + + expect(assetMock.getByIds).toHaveBeenCalledWith(['asset1']); + expect(searchMock.importAssets).toHaveBeenCalledWith([assetEntityStub.image], false); + }); + + it('should flush queued asset deletes', async () => { + sut.handleRemoveAsset({ ids: ['asset1'] }); + + jest.runOnlyPendingTimers(); + + await asyncTick(4); + + expect(searchMock.deleteAssets).toHaveBeenCalledWith(['asset1']); }); }); }); diff --git a/server/libs/domain/src/search/search.service.ts b/server/libs/domain/src/search/search.service.ts index f350e19b45..3e93f468a5 100644 --- a/server/libs/domain/src/search/search.service.ts +++ b/server/libs/domain/src/search/search.service.ts @@ -1,27 +1,64 @@ -import { AssetEntity } from '@app/infra/db/entities'; +import { MACHINE_LEARNING_ENABLED } from '@app/common'; +import { AlbumEntity, AssetEntity } from '@app/infra/db/entities'; import { BadRequestException, Inject, Injectable, Logger } from '@nestjs/common'; import { ConfigService } from '@nestjs/config'; +import { mapAlbum } from '../album'; import { IAlbumRepository } from '../album/album.repository'; +import { mapAsset } from '../asset'; import { IAssetRepository } from '../asset/asset.repository'; import { AuthUserDto } from '../auth'; -import { IAlbumJob, IAssetJob, IDeleteJob, IJobRepository, JobName } from '../job'; +import { IBulkEntityJob, IJobRepository, JobName } from '../job'; +import { IMachineLearningRepository } from '../smart-info'; import { SearchDto } from './dto'; import { SearchConfigResponseDto, SearchResponseDto } from './response-dto'; -import { ISearchRepository, SearchCollection, SearchExploreItem } from './search.repository'; +import { + ISearchRepository, + SearchCollection, + SearchExploreItem, + SearchResult, + SearchStrategy, +} from './search.repository'; + +interface SyncQueue { + upsert: Set<string>; + delete: Set<string>; +} @Injectable() export class SearchService { private logger = new Logger(SearchService.name); private enabled: boolean; + private timer: NodeJS.Timer | null = null; + + private albumQueue: SyncQueue = { + upsert: new Set(), + delete: new Set(), + }; + + private assetQueue: SyncQueue = { + upsert: new Set(), + delete: new Set(), + }; constructor( @Inject(IAlbumRepository) private albumRepository: IAlbumRepository, @Inject(IAssetRepository) private assetRepository: IAssetRepository, @Inject(IJobRepository) private jobRepository: IJobRepository, + @Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository, @Inject(ISearchRepository) private searchRepository: ISearchRepository, configService: ConfigService, ) { this.enabled = configService.get('TYPESENSE_ENABLED') !== 'false'; + if (this.enabled) { + this.timer = setInterval(() => this.flush(), 5_000); + } + } + + teardown() { + if (this.timer) { + clearInterval(this.timer); + this.timer = null; + } } isEnabled() { @@ -61,103 +98,131 @@ export class SearchService { async search(authUser: AuthUserDto, dto: SearchDto): Promise<SearchResponseDto> { this.assertEnabled(); - const query = dto.query || '*'; + const query = dto.q || dto.query || '*'; + const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT; + const filters = { userId: authUser.id, ...dto }; + + let assets: SearchResult<AssetEntity>; + switch (strategy) { + case SearchStrategy.TEXT: + assets = await this.searchRepository.searchAssets(query, filters); + break; + case SearchStrategy.CLIP: + default: + if (!MACHINE_LEARNING_ENABLED) { + throw new BadRequestException('Machine Learning is disabled'); + } + const clip = await this.machineLearning.encodeText(query); + assets = await this.searchRepository.vectorSearch(clip, filters); + } + + const albums = await this.searchRepository.searchAlbums(query, filters); return { - assets: (await this.searchRepository.search(SearchCollection.ASSETS, query, { - userId: authUser.id, - ...dto, - })) as any, - albums: (await this.searchRepository.search(SearchCollection.ALBUMS, query, { - userId: authUser.id, - ...dto, - })) as any, + albums: { ...albums, items: albums.items.map(mapAlbum) }, + assets: { ...assets, items: assets.items.map(mapAsset) }, }; } - async handleIndexAssets() { - if (!this.enabled) { - return; - } - - try { - this.logger.debug(`Running indexAssets`); - // TODO: do this in batches based on searchIndexVersion - const assets = await this.assetRepository.getAll({ isVisible: true }); - - this.logger.log(`Indexing ${assets.length} assets`); - await this.searchRepository.import(SearchCollection.ASSETS, assets, true); - this.logger.debug('Finished re-indexing all assets'); - } catch (error: any) { - this.logger.error(`Unable to index all assets`, error?.stack); - } - } - - async handleIndexAsset(data: IAssetJob) { - if (!this.enabled) { - return; - } - - const { asset } = data; - if (!asset.isVisible) { - return; - } - - try { - await this.searchRepository.index(SearchCollection.ASSETS, asset); - } catch (error: any) { - this.logger.error(`Unable to index asset: ${asset.id}`, error?.stack); - } - } - async handleIndexAlbums() { if (!this.enabled) { return; } try { - const albums = await this.albumRepository.getAll(); + const albums = this.patchAlbums(await this.albumRepository.getAll()); this.logger.log(`Indexing ${albums.length} albums`); - await this.searchRepository.import(SearchCollection.ALBUMS, albums, true); - this.logger.debug('Finished re-indexing all albums'); + await this.searchRepository.importAlbums(albums, true); } catch (error: any) { this.logger.error(`Unable to index all albums`, error?.stack); } } - async handleIndexAlbum(data: IAlbumJob) { + async handleIndexAssets() { if (!this.enabled) { return; } - const { album } = data; - try { - await this.searchRepository.index(SearchCollection.ALBUMS, album); + // TODO: do this in batches based on searchIndexVersion + const assets = this.patchAssets(await this.assetRepository.getAll({ isVisible: true })); + this.logger.log(`Indexing ${assets.length} assets`); + await this.searchRepository.importAssets(assets, true); + this.logger.debug('Finished re-indexing all assets'); } catch (error: any) { - this.logger.error(`Unable to index album: ${album.id}`, error?.stack); + this.logger.error(`Unable to index all assets`, error?.stack); } } - async handleRemoveAlbum(data: IDeleteJob) { - await this.handleRemove(SearchCollection.ALBUMS, data); - } - - async handleRemoveAsset(data: IDeleteJob) { - await this.handleRemove(SearchCollection.ASSETS, data); - } - - private async handleRemove(collection: SearchCollection, data: IDeleteJob) { + handleIndexAlbum({ ids }: IBulkEntityJob) { if (!this.enabled) { return; } - const { id } = data; + for (const id of ids) { + this.albumQueue.upsert.add(id); + } + } - try { - await this.searchRepository.delete(collection, id); - } catch (error: any) { - this.logger.error(`Unable to remove ${collection}: ${id}`, error?.stack); + handleIndexAsset({ ids }: IBulkEntityJob) { + if (!this.enabled) { + return; + } + + for (const id of ids) { + this.assetQueue.upsert.add(id); + } + } + + handleRemoveAlbum({ ids }: IBulkEntityJob) { + if (!this.enabled) { + return; + } + + for (const id of ids) { + this.albumQueue.delete.add(id); + } + } + + handleRemoveAsset({ ids }: IBulkEntityJob) { + if (!this.enabled) { + return; + } + + for (const id of ids) { + this.assetQueue.delete.add(id); + } + } + + private async flush() { + if (this.albumQueue.upsert.size > 0) { + const ids = [...this.albumQueue.upsert.keys()]; + const items = await this.idsToAlbums(ids); + this.logger.debug(`Flushing ${items.length} album upserts`); + await this.searchRepository.importAlbums(items, false); + this.albumQueue.upsert.clear(); + } + + if (this.albumQueue.delete.size > 0) { + const ids = [...this.albumQueue.delete.keys()]; + this.logger.debug(`Flushing ${ids.length} album deletes`); + await this.searchRepository.deleteAlbums(ids); + this.albumQueue.delete.clear(); + } + + if (this.assetQueue.upsert.size > 0) { + const ids = [...this.assetQueue.upsert.keys()]; + const items = await this.idsToAssets(ids); + this.logger.debug(`Flushing ${items.length} asset upserts`); + await this.searchRepository.importAssets(items, false); + this.assetQueue.upsert.clear(); + } + + if (this.assetQueue.delete.size > 0) { + const ids = [...this.assetQueue.delete.keys()]; + this.logger.debug(`Flushing ${ids.length} asset deletes`); + await this.searchRepository.deleteAssets(ids); + this.assetQueue.delete.clear(); } } @@ -166,4 +231,22 @@ export class SearchService { throw new BadRequestException('Search is disabled'); } } + + private async idsToAlbums(ids: string[]): Promise<AlbumEntity[]> { + const entities = await this.albumRepository.getByIds(ids); + return this.patchAlbums(entities); + } + + private async idsToAssets(ids: string[]): Promise<AssetEntity[]> { + const entities = await this.assetRepository.getByIds(ids); + return this.patchAssets(entities.filter((entity) => entity.isVisible)); + } + + private patchAssets(assets: AssetEntity[]): AssetEntity[] { + return assets; + } + + private patchAlbums(albums: AlbumEntity[]): AlbumEntity[] { + return albums.map((entity) => ({ ...entity, assets: [] })); + } } diff --git a/server/libs/domain/src/smart-info/machine-learning.interface.ts b/server/libs/domain/src/smart-info/machine-learning.interface.ts index a175890814..a42bf30231 100644 --- a/server/libs/domain/src/smart-info/machine-learning.interface.ts +++ b/server/libs/domain/src/smart-info/machine-learning.interface.ts @@ -7,4 +7,6 @@ export interface MachineLearningInput { export interface IMachineLearningRepository { tagImage(input: MachineLearningInput): Promise<string[]>; detectObjects(input: MachineLearningInput): Promise<string[]>; + encodeImage(input: MachineLearningInput): Promise<number[]>; + encodeText(input: string): Promise<number[]>; } diff --git a/server/libs/domain/src/smart-info/smart-info.service.spec.ts b/server/libs/domain/src/smart-info/smart-info.service.spec.ts index 7d859ba8b4..41e3887b6b 100644 --- a/server/libs/domain/src/smart-info/smart-info.service.spec.ts +++ b/server/libs/domain/src/smart-info/smart-info.service.spec.ts @@ -1,5 +1,6 @@ import { AssetEntity } from '@app/infra/db/entities'; -import { newMachineLearningRepositoryMock, newSmartInfoRepositoryMock } from '../../test'; +import { newJobRepositoryMock, newMachineLearningRepositoryMock, newSmartInfoRepositoryMock } from '../../test'; +import { IJobRepository } from '../job'; import { IMachineLearningRepository } from './machine-learning.interface'; import { ISmartInfoRepository } from './smart-info.repository'; import { SmartInfoService } from './smart-info.service'; @@ -11,13 +12,15 @@ const asset = { describe(SmartInfoService.name, () => { let sut: SmartInfoService; + let jobMock: jest.Mocked<IJobRepository>; let smartMock: jest.Mocked<ISmartInfoRepository>; let machineMock: jest.Mocked<IMachineLearningRepository>; beforeEach(async () => { smartMock = newSmartInfoRepositoryMock(); + jobMock = newJobRepositoryMock(); machineMock = newMachineLearningRepositoryMock(); - sut = new SmartInfoService(smartMock, machineMock); + sut = new SmartInfoService(jobMock, smartMock, machineMock); }); it('should work', () => { diff --git a/server/libs/domain/src/smart-info/smart-info.service.ts b/server/libs/domain/src/smart-info/smart-info.service.ts index f3185e58f6..2621576eed 100644 --- a/server/libs/domain/src/smart-info/smart-info.service.ts +++ b/server/libs/domain/src/smart-info/smart-info.service.ts @@ -1,6 +1,6 @@ import { MACHINE_LEARNING_ENABLED } from '@app/common'; import { Inject, Injectable, Logger } from '@nestjs/common'; -import { IAssetJob } from '../job'; +import { IAssetJob, IJobRepository, JobName } from '../job'; import { IMachineLearningRepository } from './machine-learning.interface'; import { ISmartInfoRepository } from './smart-info.repository'; @@ -9,6 +9,7 @@ export class SmartInfoService { private logger = new Logger(SmartInfoService.name); constructor( + @Inject(IJobRepository) private jobRepository: IJobRepository, @Inject(ISmartInfoRepository) private repository: ISmartInfoRepository, @Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository, ) {} @@ -24,6 +25,7 @@ export class SmartInfoService { const tags = await this.machineLearning.tagImage({ thumbnailPath: asset.resizePath }); if (tags.length > 0) { await this.repository.upsert({ assetId: asset.id, tags }); + await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } }); } } catch (error: any) { this.logger.error(`Unable to run image tagging pipeline: ${asset.id}`, error?.stack); @@ -41,9 +43,26 @@ export class SmartInfoService { const objects = await this.machineLearning.detectObjects({ thumbnailPath: asset.resizePath }); if (objects.length > 0) { await this.repository.upsert({ assetId: asset.id, objects }); + await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } }); } } catch (error: any) { this.logger.error(`Unable run object detection pipeline: ${asset.id}`, error?.stack); } } + + async handleEncodeClip(data: IAssetJob) { + const { asset } = data; + + if (!MACHINE_LEARNING_ENABLED || !asset.resizePath) { + return; + } + + try { + const clipEmbedding = await this.machineLearning.encodeImage({ thumbnailPath: asset.resizePath }); + await this.repository.upsert({ assetId: asset.id, clipEmbedding: clipEmbedding }); + await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } }); + } catch (error: any) { + this.logger.error(`Unable run clip encoding pipeline: ${asset.id}`, error?.stack); + } + } } diff --git a/server/libs/domain/test/album.repository.mock.ts b/server/libs/domain/test/album.repository.mock.ts index dc21e5ecbb..2c4a5500a2 100644 --- a/server/libs/domain/test/album.repository.mock.ts +++ b/server/libs/domain/test/album.repository.mock.ts @@ -2,6 +2,7 @@ import { IAlbumRepository } from '../src'; export const newAlbumRepositoryMock = (): jest.Mocked<IAlbumRepository> => { return { + getByIds: jest.fn(), deleteAll: jest.fn(), getAll: jest.fn(), save: jest.fn(), diff --git a/server/libs/domain/test/asset.repository.mock.ts b/server/libs/domain/test/asset.repository.mock.ts index d65477d126..b56bd1419b 100644 --- a/server/libs/domain/test/asset.repository.mock.ts +++ b/server/libs/domain/test/asset.repository.mock.ts @@ -2,6 +2,7 @@ import { IAssetRepository } from '../src'; export const newAssetRepositoryMock = (): jest.Mocked<IAssetRepository> => { return { + getByIds: jest.fn(), getAll: jest.fn(), deleteAll: jest.fn(), save: jest.fn(), diff --git a/server/libs/domain/test/fixtures.ts b/server/libs/domain/test/fixtures.ts index 27efcd2174..b9d9a6cbe1 100644 --- a/server/libs/domain/test/fixtures.ts +++ b/server/libs/domain/test/fixtures.ts @@ -15,6 +15,7 @@ import { AuthUserDto, ExifResponseDto, mapUser, + SearchResult, SharedLinkResponseDto, } from '../src'; @@ -448,6 +449,7 @@ export const sharedLinkStub = { tags: [], objects: ['a', 'b', 'c'], asset: null as any, + clipEmbedding: [0.12, 0.13, 0.14], }, webpPath: '', encodedVideoPath: '', @@ -550,3 +552,13 @@ export const sharedLinkResponseStub = { // TODO - the constructor isn't used anywhere, so not test coverage new ExifResponseDto(); + +export const searchStub = { + emptyResults: Object.freeze<SearchResult<any>>({ + total: 0, + count: 0, + page: 1, + items: [], + facets: [], + }), +}; diff --git a/server/libs/domain/test/index.ts b/server/libs/domain/test/index.ts index aec1d0c15a..45f5825464 100644 --- a/server/libs/domain/test/index.ts +++ b/server/libs/domain/test/index.ts @@ -13,3 +13,9 @@ export * from './storage.repository.mock'; export * from './system-config.repository.mock'; export * from './user-token.repository.mock'; export * from './user.repository.mock'; + +export async function asyncTick(steps: number) { + for (let i = 0; i < steps; i++) { + await Promise.resolve(); + } +} diff --git a/server/libs/domain/test/machine-learning.repository.mock.ts b/server/libs/domain/test/machine-learning.repository.mock.ts index 7c7c2b2553..0bc06814ec 100644 --- a/server/libs/domain/test/machine-learning.repository.mock.ts +++ b/server/libs/domain/test/machine-learning.repository.mock.ts @@ -4,5 +4,7 @@ export const newMachineLearningRepositoryMock = (): jest.Mocked<IMachineLearning return { tagImage: jest.fn(), detectObjects: jest.fn(), + encodeImage: jest.fn(), + encodeText: jest.fn(), }; }; diff --git a/server/libs/domain/test/search.repository.mock.ts b/server/libs/domain/test/search.repository.mock.ts index 0ba2dd4f9c..5a4fcdf217 100644 --- a/server/libs/domain/test/search.repository.mock.ts +++ b/server/libs/domain/test/search.repository.mock.ts @@ -4,10 +4,13 @@ export const newSearchRepositoryMock = (): jest.Mocked<ISearchRepository> => { return { setup: jest.fn(), checkMigrationStatus: jest.fn(), - index: jest.fn(), - import: jest.fn(), - search: jest.fn(), - delete: jest.fn(), + importAssets: jest.fn(), + importAlbums: jest.fn(), + deleteAlbums: jest.fn(), + deleteAssets: jest.fn(), + searchAssets: jest.fn(), + searchAlbums: jest.fn(), + vectorSearch: jest.fn(), explore: jest.fn(), }; }; diff --git a/server/libs/infra/src/db/entities/smart-info.entity.ts b/server/libs/infra/src/db/entities/smart-info.entity.ts index ae3edd8404..20edfaf43c 100644 --- a/server/libs/infra/src/db/entities/smart-info.entity.ts +++ b/server/libs/infra/src/db/entities/smart-info.entity.ts @@ -15,4 +15,14 @@ export class SmartInfoEntity { @Column({ type: 'text', array: true, nullable: true }) objects!: string[] | null; + + @Column({ + type: 'numeric', + array: true, + nullable: true, + // note: migration generator is broken for numeric[], but these _are_ set in the database + // precision: 20, + // scale: 19, + }) + clipEmbedding!: number[] | null; } diff --git a/server/libs/infra/src/db/migrations/1677971458822-AddCLIPEncodeDataColumn.ts b/server/libs/infra/src/db/migrations/1677971458822-AddCLIPEncodeDataColumn.ts new file mode 100644 index 0000000000..82f8176b0d --- /dev/null +++ b/server/libs/infra/src/db/migrations/1677971458822-AddCLIPEncodeDataColumn.ts @@ -0,0 +1,13 @@ +import { MigrationInterface, QueryRunner } from 'typeorm'; + +export class AddCLIPEncodeDataColumn1677971458822 implements MigrationInterface { + name = 'AddCLIPEncodeDataColumn1677971458822'; + + public async up(queryRunner: QueryRunner): Promise<void> { + await queryRunner.query(`ALTER TABLE "smart_info" ADD "clipEmbedding" numeric(20,19) array`); + } + + public async down(queryRunner: QueryRunner): Promise<void> { + await queryRunner.query(`ALTER TABLE "smart_info" DROP COLUMN "clipEmbedding"`); + } +} diff --git a/server/libs/infra/src/db/repository/album.repository.ts b/server/libs/infra/src/db/repository/album.repository.ts index d4eca4e500..9542227fd4 100644 --- a/server/libs/infra/src/db/repository/album.repository.ts +++ b/server/libs/infra/src/db/repository/album.repository.ts @@ -1,19 +1,34 @@ import { IAlbumRepository } from '@app/domain'; import { Injectable } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; -import { Repository } from 'typeorm'; +import { In, Repository } from 'typeorm'; import { AlbumEntity } from '../entities'; @Injectable() export class AlbumRepository implements IAlbumRepository { constructor(@InjectRepository(AlbumEntity) private repository: Repository<AlbumEntity>) {} + getByIds(ids: string[]): Promise<AlbumEntity[]> { + return this.repository.find({ + where: { + id: In(ids), + }, + relations: { + owner: true, + }, + }); + } + async deleteAll(userId: string): Promise<void> { await this.repository.delete({ ownerId: userId }); } getAll(): Promise<AlbumEntity[]> { - return this.repository.find(); + return this.repository.find({ + relations: { + owner: true, + }, + }); } async save(album: Partial<AlbumEntity>) { diff --git a/server/libs/infra/src/db/repository/asset.repository.ts b/server/libs/infra/src/db/repository/asset.repository.ts index 6f0e65684e..17adb46bb3 100644 --- a/server/libs/infra/src/db/repository/asset.repository.ts +++ b/server/libs/infra/src/db/repository/asset.repository.ts @@ -1,13 +1,24 @@ import { AssetSearchOptions, IAssetRepository } from '@app/domain'; import { Injectable } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; -import { Not, Repository } from 'typeorm'; +import { In, Not, Repository } from 'typeorm'; import { AssetEntity, AssetType } from '../entities'; @Injectable() export class AssetRepository implements IAssetRepository { constructor(@InjectRepository(AssetEntity) private repository: Repository<AssetEntity>) {} + getByIds(ids: string[]): Promise<AssetEntity[]> { + return this.repository.find({ + where: { id: In(ids) }, + relations: { + exifInfo: true, + smartInfo: true, + tags: true, + }, + }); + } + async deleteAll(ownerId: string): Promise<void> { await this.repository.delete({ ownerId }); } diff --git a/server/libs/infra/src/job/job.repository.ts b/server/libs/infra/src/job/job.repository.ts index e83ce06034..53f3fb9a78 100644 --- a/server/libs/infra/src/job/job.repository.ts +++ b/server/libs/infra/src/job/job.repository.ts @@ -41,6 +41,7 @@ export class JobRepository implements IJobRepository { case JobName.OBJECT_DETECTION: case JobName.IMAGE_TAGGING: + case JobName.ENCODE_CLIP: await this.machineLearning.add(item.name, item.data); break; @@ -73,7 +74,7 @@ export class JobRepository implements IJobRepository { case JobName.SEARCH_INDEX_ASSETS: case JobName.SEARCH_INDEX_ALBUMS: - await this.searchIndex.add(item.name); + await this.searchIndex.add(item.name, {}); break; case JobName.SEARCH_INDEX_ASSET: diff --git a/server/libs/infra/src/machine-learning/machine-learning.repository.ts b/server/libs/infra/src/machine-learning/machine-learning.repository.ts index e69e068ecc..a69d687445 100644 --- a/server/libs/infra/src/machine-learning/machine-learning.repository.ts +++ b/server/libs/infra/src/machine-learning/machine-learning.repository.ts @@ -14,4 +14,12 @@ export class MachineLearningRepository implements IMachineLearningRepository { detectObjects(input: MachineLearningInput): Promise<string[]> { return client.post<string[]>('/object-detection/detect-object', input).then((res) => res.data); } + + encodeImage(input: MachineLearningInput): Promise<number[]> { + return client.post<number[]>('/sentence-transformer/encode-image', input).then((res) => res.data); + } + + encodeText(input: string): Promise<number[]> { + return client.post<number[]>('/sentence-transformer/encode-text', { text: input }).then((res) => res.data); + } } diff --git a/server/libs/infra/src/search/schemas/asset.schema.ts b/server/libs/infra/src/search/schemas/asset.schema.ts index d379048c97..28a61cf097 100644 --- a/server/libs/infra/src/search/schemas/asset.schema.ts +++ b/server/libs/infra/src/search/schemas/asset.schema.ts @@ -1,6 +1,6 @@ import { CollectionCreateSchema } from 'typesense/lib/Typesense/Collections'; -export const assetSchemaVersion = 2; +export const assetSchemaVersion = 3; export const assetSchema: CollectionCreateSchema = { name: `assets-v${assetSchemaVersion}`, fields: [ @@ -29,6 +29,7 @@ export const assetSchema: CollectionCreateSchema = { // smart info { name: 'smartInfo.objects', type: 'string[]', facet: true, optional: true }, { name: 'smartInfo.tags', type: 'string[]', facet: true, optional: true }, + { name: 'smartInfo.clipEmbedding', type: 'float[]', facet: false, optional: true, num_dim: 512 }, // computed { name: 'geo', type: 'geopoint', facet: false, optional: true }, diff --git a/server/libs/infra/src/search/typesense.repository.ts b/server/libs/infra/src/search/typesense.repository.ts index a656d4b24e..f5bb23d8eb 100644 --- a/server/libs/infra/src/search/typesense.repository.ts +++ b/server/libs/infra/src/search/typesense.repository.ts @@ -16,12 +16,7 @@ import { AlbumEntity, AssetEntity } from '../db'; import { albumSchema } from './schemas/album.schema'; import { assetSchema } from './schemas/asset.schema'; -interface CustomAssetEntity extends AssetEntity { - geo?: [number, number]; - motion?: boolean; -} - -function removeNil<T extends Dictionary<any>>(item: T): Partial<T> { +function removeNil<T extends Dictionary<any>>(item: T): T { _.forOwn(item, (value, key) => { if (_.isNil(value) || (_.isObject(value) && !_.isDate(value) && _.isEmpty(removeNil(value)))) { delete item[key]; @@ -31,6 +26,11 @@ function removeNil<T extends Dictionary<any>>(item: T): Partial<T> { return item; } +interface CustomAssetEntity extends AssetEntity { + geo?: [number, number]; + motion?: boolean; +} + const schemaMap: Record<SearchCollection, CollectionCreateSchema> = { [SearchCollection.ASSETS]: assetSchema, [SearchCollection.ALBUMS]: albumSchema, @@ -38,24 +38,9 @@ const schemaMap: Record<SearchCollection, CollectionCreateSchema> = { const schemas = Object.entries(schemaMap) as [SearchCollection, CollectionCreateSchema][]; -interface SearchUpdateQueue<T = any> { - upsert: T[]; - delete: string[]; -} - @Injectable() export class TypesenseRepository implements ISearchRepository { private logger = new Logger(TypesenseRepository.name); - private queue: Record<SearchCollection, SearchUpdateQueue> = { - [SearchCollection.ASSETS]: { - upsert: [], - delete: [], - }, - [SearchCollection.ALBUMS]: { - upsert: [], - delete: [], - }, - }; private _client: Client | null = null; private get client(): Client { @@ -83,8 +68,6 @@ export class TypesenseRepository implements ISearchRepository { numRetries: 3, connectionTimeoutSeconds: 10, }); - - setInterval(() => this.flush(), 5_000); } async setup(): Promise<void> { @@ -131,48 +114,27 @@ export class TypesenseRepository implements ISearchRepository { return migrationMap; } - async index(collection: SearchCollection, item: AssetEntity | AlbumEntity, immediate?: boolean): Promise<void> { - const schema = schemaMap[collection]; - - if (collection === SearchCollection.ASSETS) { - item = this.patchAsset(item as AssetEntity); - } - - if (immediate) { - await this.client.collections(schema.name).documents().upsert(item); - return; - } - - this.queue[collection].upsert.push(item); + async importAlbums(items: AlbumEntity[], done: boolean): Promise<void> { + await this.import(SearchCollection.ALBUMS, items, done); } - async delete(collection: SearchCollection, id: string, immediate?: boolean): Promise<void> { - const schema = schemaMap[collection]; - - if (immediate) { - await this.client.collections(schema.name).documents().delete(id); - return; - } - - this.queue[collection].delete.push(id); + async importAssets(items: AssetEntity[], done: boolean): Promise<void> { + await this.import(SearchCollection.ASSETS, items, done); } - async import(collection: SearchCollection, items: AssetEntity[] | AlbumEntity[], done: boolean): Promise<void> { + private async import( + collection: SearchCollection, + items: AlbumEntity[] | AssetEntity[], + done: boolean, + ): Promise<void> { try { - const schema = schemaMap[collection]; - const _items = items.map((item) => { - if (collection === SearchCollection.ASSETS) { - item = this.patchAsset(item as AssetEntity); - } - // null values are invalid for typesense documents - return removeNil(item); - }); - if (_items.length > 0) { - await this.client - .collections(schema.name) - .documents() - .import(_items, { action: 'upsert', dirty_values: 'coerce_or_drop' }); + if (items.length > 0) { + await this.client.collections(schemaMap[collection].name).documents().import(this.patch(collection, items), { + action: 'upsert', + dirty_values: 'coerce_or_drop', + }); } + if (done) { await this.updateAlias(collection); } @@ -234,71 +196,81 @@ export class TypesenseRepository implements ISearchRepository { ); } - search(collection: SearchCollection.ASSETS, query: string, filter: SearchFilter): Promise<SearchResult<AssetEntity>>; - search(collection: SearchCollection.ALBUMS, query: string, filter: SearchFilter): Promise<SearchResult<AlbumEntity>>; - async search(collection: SearchCollection, query: string, filters: SearchFilter) { - const alias = await this.client.aliases(collection).retrieve(); - - const { userId } = filters; - - const _filters = [`ownerId:${userId}`]; - - if (filters.id) { - _filters.push(`id:=${filters.id}`); - } - if (collection === SearchCollection.ASSETS) { - for (const item of schemaMap[collection].fields || []) { - let value = filters[item.name as keyof SearchFilter]; - if (Array.isArray(value)) { - value = `[${value.join(',')}]`; - } - if (item.facet && value !== undefined) { - _filters.push(`${item.name}:${value}`); - } - } - - this.logger.debug(`Searching query='${query}', filters='${JSON.stringify(_filters)}'`); - - const results = await this.client - .collections<AssetEntity>(alias.collection_name) - .documents() - .search({ - q: query, - query_by: [ - 'exifInfo.imageName', - 'exifInfo.country', - 'exifInfo.state', - 'exifInfo.city', - 'exifInfo.description', - 'smartInfo.tags', - 'smartInfo.objects', - ].join(','), - filter_by: _filters.join(' && '), - per_page: 250, - sort_by: filters.recent ? 'createdAt:desc' : undefined, - facet_by: this.getFacetFieldNames(SearchCollection.ASSETS), - }); - - return this.asResponse(results); - } - - if (collection === SearchCollection.ALBUMS) { - const results = await this.client - .collections<AlbumEntity>(alias.collection_name) - .documents() - .search({ - q: query, - query_by: 'albumName', - filter_by: _filters.join(','), - }); - - return this.asResponse(results); - } - - throw new Error(`Invalid collection: ${collection}`); + async deleteAlbums(ids: string[]): Promise<void> { + await this.delete(SearchCollection.ALBUMS, ids); } - private asResponse<T extends DocumentSchema>(results: SearchResponse<T>): SearchResult<T> { + async deleteAssets(ids: string[]): Promise<void> { + await this.delete(SearchCollection.ASSETS, ids); + } + + async delete(collection: SearchCollection, ids: string[]): Promise<void> { + await this.client + .collections(schemaMap[collection].name) + .documents() + .delete({ filter_by: `id: [${ids.join(',')}]` }); + } + + async searchAlbums(query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>> { + const alias = await this.client.aliases(SearchCollection.ALBUMS).retrieve(); + + const results = await this.client + .collections<AlbumEntity>(alias.collection_name) + .documents() + .search({ + q: query, + query_by: 'albumName', + filter_by: this.getAlbumFilters(filters), + }); + + return this.asResponse(results, filters.debug); + } + + async searchAssets(query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>> { + const alias = await this.client.aliases(SearchCollection.ASSETS).retrieve(); + const results = await this.client + .collections<AssetEntity>(alias.collection_name) + .documents() + .search({ + q: query, + query_by: [ + 'exifInfo.imageName', + 'exifInfo.country', + 'exifInfo.state', + 'exifInfo.city', + 'exifInfo.description', + 'smartInfo.tags', + 'smartInfo.objects', + ].join(','), + per_page: 250, + facet_by: this.getFacetFieldNames(SearchCollection.ASSETS), + filter_by: this.getAssetFilters(filters), + sort_by: filters.recent ? 'createdAt:desc' : undefined, + }); + + return this.asResponse(results, filters.debug); + } + + async vectorSearch(input: number[], filters: SearchFilter): Promise<SearchResult<AssetEntity>> { + const alias = await this.client.aliases(SearchCollection.ASSETS).retrieve(); + + const { results } = await this.client.multiSearch.perform({ + searches: [ + { + collection: alias.collection_name, + q: '*', + vector_query: `smartInfo.clipEmbedding:([${input.join(',')}], k:100)`, + per_page: 250, + facet_by: this.getFacetFieldNames(SearchCollection.ASSETS), + filter_by: this.getAssetFilters(filters), + } as any, + ], + }); + + return this.asResponse(results[0] as SearchResponse<AssetEntity>, filters.debug); + } + + private asResponse<T extends DocumentSchema>(results: SearchResponse<T>, debug?: boolean): SearchResult<T> { return { page: results.page, total: results.found, @@ -308,51 +280,23 @@ export class TypesenseRepository implements ISearchRepository { counts: facet.counts.map((item) => ({ count: item.count, value: item.value })), fieldName: facet.field_name as string, })), - }; + debug: debug ? results : undefined, + } as SearchResult<T>; } - private async flush() { - for (const [collection, schema] of schemas) { - if (this.queue[collection].upsert.length > 0) { - try { - const items = this.queue[collection].upsert.map((item) => removeNil(item)); - this.logger.debug(`Flushing ${items.length} ${collection} upserts to typesense`); - await this.client - .collections(schema.name) - .documents() - .import(items, { action: 'upsert', dirty_values: 'coerce_or_drop' }); - this.queue[collection].upsert = []; - } catch (error) { - this.handleError(error); - } - } - - if (this.queue[collection].delete.length > 0) { - try { - const items = this.queue[collection].delete; - this.logger.debug(`Flushing ${items.length} ${collection} deletes to typesense`); - await this.client - .collections(schema.name) - .documents() - .delete({ filter_by: `id: [${items.join(',')}]` }); - this.queue[collection].delete = []; - } catch (error) { - this.handleError(error); - } - } - } - } - - private handleError(error: any): never { + private handleError(error: any) { this.logger.error('Unable to index documents'); const results = error.importResults || []; for (const result of results) { try { result.document = JSON.parse(result.document); + if (result.document?.smartInfo?.clipEmbedding) { + result.document.smartInfo.clipEmbedding = '<truncated>'; + } } catch {} } + this.logger.verbose(JSON.stringify(results, null, 2)); - throw error; } private async updateAlias(collection: SearchCollection) { @@ -373,6 +317,18 @@ export class TypesenseRepository implements ISearchRepository { } } + private patch(collection: SearchCollection, items: AssetEntity[] | AlbumEntity[]) { + return items.map((item) => + collection === SearchCollection.ASSETS + ? this.patchAsset(item as AssetEntity) + : this.patchAlbum(item as AlbumEntity), + ); + } + + private patchAlbum(album: AlbumEntity): AlbumEntity { + return removeNil(album); + } + private patchAsset(asset: AssetEntity): CustomAssetEntity { let custom = asset as CustomAssetEntity; @@ -382,9 +338,7 @@ export class TypesenseRepository implements ISearchRepository { custom = { ...custom, geo: [lat, lng] }; } - custom = { ...custom, motion: !!asset.livePhotoVideoId }; - - return custom; + return removeNil({ ...custom, motion: !!asset.livePhotoVideoId }); } private getFacetFieldNames(collection: SearchCollection) { @@ -393,4 +347,41 @@ export class TypesenseRepository implements ISearchRepository { .map((field) => field.name) .join(','); } + + private getAlbumFilters(filters: SearchFilter) { + const { userId } = filters; + const _filters = [`ownerId:${userId}`]; + if (filters.id) { + _filters.push(`id:=${filters.id}`); + } + + for (const item of albumSchema.fields || []) { + let value = filters[item.name as keyof SearchFilter]; + if (Array.isArray(value)) { + value = `[${value.join(',')}]`; + } + if (item.facet && value !== undefined) { + _filters.push(`${item.name}:${value}`); + } + } + + return _filters.join(' && '); + } + + private getAssetFilters(filters: SearchFilter) { + const _filters = [`ownerId:${filters.userId}`]; + if (filters.id) { + _filters.push(`id:=${filters.id}`); + } + for (const item of assetSchema.fields || []) { + let value = filters[item.name as keyof SearchFilter]; + if (Array.isArray(value)) { + value = `[${value.join(',')}]`; + } + if (item.facet && value !== undefined) { + _filters.push(`${item.name}:${value}`); + } + } + return _filters.join(' && '); + } } diff --git a/server/package-lock.json b/server/package-lock.json index 5e26c539cf..f0dbb73610 100644 --- a/server/package-lock.json +++ b/server/package-lock.json @@ -48,7 +48,7 @@ "sanitize-filename": "^1.6.3", "sharp": "^0.28.0", "typeorm": "^0.3.11", - "typesense": "^1.5.2" + "typesense": "^1.5.3" }, "bin": { "immich": "bin/cli.sh" @@ -11137,9 +11137,9 @@ } }, "node_modules/typesense": { - "version": "1.5.2", - "resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.2.tgz", - "integrity": "sha512-ysARFw+4z3AdSViOACqf7K9TXoP2wAXd5p5uSGTdXW14UYjcEzpV/S/EhMoiC6YdZyrnbDdNsxgWbf+AWJ9Udw==", + "version": "1.5.3", + "resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.3.tgz", + "integrity": "sha512-eLHBP6AHex04tT+q/a7Uc+dFjIuoKTRpvlsNJwVTyedh4n0qnJxbfoLJBCxzhhZn5eITjEK0oWvVZ5byc3E+Ww==", "dependencies": { "axios": "^0.26.0", "loglevel": "^1.8.0" @@ -20023,9 +20023,9 @@ "devOptional": true }, "typesense": { - "version": "1.5.2", - "resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.2.tgz", - "integrity": "sha512-ysARFw+4z3AdSViOACqf7K9TXoP2wAXd5p5uSGTdXW14UYjcEzpV/S/EhMoiC6YdZyrnbDdNsxgWbf+AWJ9Udw==", + "version": "1.5.3", + "resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.3.tgz", + "integrity": "sha512-eLHBP6AHex04tT+q/a7Uc+dFjIuoKTRpvlsNJwVTyedh4n0qnJxbfoLJBCxzhhZn5eITjEK0oWvVZ5byc3E+Ww==", "requires": { "axios": "^0.26.0", "loglevel": "^1.8.0" diff --git a/server/package.json b/server/package.json index 1e3b8f9105..2b8d19121f 100644 --- a/server/package.json +++ b/server/package.json @@ -78,7 +78,7 @@ "sanitize-filename": "^1.6.3", "sharp": "^0.28.0", "typeorm": "^0.3.11", - "typesense": "^1.5.2" + "typesense": "^1.5.3" }, "devDependencies": { "@nestjs/cli": "^9.1.8", diff --git a/web/src/api/open-api/api.ts b/web/src/api/open-api/api.ts index 69a66a5679..5b514cc3f4 100644 --- a/web/src/api/open-api/api.ts +++ b/web/src/api/open-api/api.ts @@ -6739,22 +6739,10 @@ export const SearchApiAxiosParamCreator = function (configuration?: Configuratio }, /** * - * @param {string} [query] - * @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type] - * @param {boolean} [isFavorite] - * @param {string} [exifInfoCity] - * @param {string} [exifInfoState] - * @param {string} [exifInfoCountry] - * @param {string} [exifInfoMake] - * @param {string} [exifInfoModel] - * @param {Array<string>} [smartInfoObjects] - * @param {Array<string>} [smartInfoTags] - * @param {boolean} [recent] - * @param {boolean} [motion] * @param {*} [options] Override http request option. * @throws {RequiredError} */ - search: async (query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options: AxiosRequestConfig = {}): Promise<RequestArgs> => { + search: async (options: AxiosRequestConfig = {}): Promise<RequestArgs> => { const localVarPath = `/search`; // use dummy base URL string because the URL constructor only accepts absolute URLs. const localVarUrlObj = new URL(localVarPath, DUMMY_BASE_URL); @@ -6773,54 +6761,6 @@ export const SearchApiAxiosParamCreator = function (configuration?: Configuratio // authentication cookie required - if (query !== undefined) { - localVarQueryParameter['query'] = query; - } - - if (type !== undefined) { - localVarQueryParameter['type'] = type; - } - - if (isFavorite !== undefined) { - localVarQueryParameter['isFavorite'] = isFavorite; - } - - if (exifInfoCity !== undefined) { - localVarQueryParameter['exifInfo.city'] = exifInfoCity; - } - - if (exifInfoState !== undefined) { - localVarQueryParameter['exifInfo.state'] = exifInfoState; - } - - if (exifInfoCountry !== undefined) { - localVarQueryParameter['exifInfo.country'] = exifInfoCountry; - } - - if (exifInfoMake !== undefined) { - localVarQueryParameter['exifInfo.make'] = exifInfoMake; - } - - if (exifInfoModel !== undefined) { - localVarQueryParameter['exifInfo.model'] = exifInfoModel; - } - - if (smartInfoObjects) { - localVarQueryParameter['smartInfo.objects'] = smartInfoObjects; - } - - if (smartInfoTags) { - localVarQueryParameter['smartInfo.tags'] = smartInfoTags; - } - - if (recent !== undefined) { - localVarQueryParameter['recent'] = recent; - } - - if (motion !== undefined) { - localVarQueryParameter['motion'] = motion; - } - setSearchParams(localVarUrlObj, localVarQueryParameter); @@ -6862,23 +6802,11 @@ export const SearchApiFp = function(configuration?: Configuration) { }, /** * - * @param {string} [query] - * @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type] - * @param {boolean} [isFavorite] - * @param {string} [exifInfoCity] - * @param {string} [exifInfoState] - * @param {string} [exifInfoCountry] - * @param {string} [exifInfoMake] - * @param {string} [exifInfoModel] - * @param {Array<string>} [smartInfoObjects] - * @param {Array<string>} [smartInfoTags] - * @param {boolean} [recent] - * @param {boolean} [motion] * @param {*} [options] Override http request option. * @throws {RequiredError} */ - async search(query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options?: AxiosRequestConfig): Promise<(axios?: AxiosInstance, basePath?: string) => AxiosPromise<SearchResponseDto>> { - const localVarAxiosArgs = await localVarAxiosParamCreator.search(query, type, isFavorite, exifInfoCity, exifInfoState, exifInfoCountry, exifInfoMake, exifInfoModel, smartInfoObjects, smartInfoTags, recent, motion, options); + async search(options?: AxiosRequestConfig): Promise<(axios?: AxiosInstance, basePath?: string) => AxiosPromise<SearchResponseDto>> { + const localVarAxiosArgs = await localVarAxiosParamCreator.search(options); return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration); }, } @@ -6909,23 +6837,11 @@ export const SearchApiFactory = function (configuration?: Configuration, basePat }, /** * - * @param {string} [query] - * @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type] - * @param {boolean} [isFavorite] - * @param {string} [exifInfoCity] - * @param {string} [exifInfoState] - * @param {string} [exifInfoCountry] - * @param {string} [exifInfoMake] - * @param {string} [exifInfoModel] - * @param {Array<string>} [smartInfoObjects] - * @param {Array<string>} [smartInfoTags] - * @param {boolean} [recent] - * @param {boolean} [motion] * @param {*} [options] Override http request option. * @throws {RequiredError} */ - search(query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options?: any): AxiosPromise<SearchResponseDto> { - return localVarFp.search(query, type, isFavorite, exifInfoCity, exifInfoState, exifInfoCountry, exifInfoMake, exifInfoModel, smartInfoObjects, smartInfoTags, recent, motion, options).then((request) => request(axios, basePath)); + search(options?: any): AxiosPromise<SearchResponseDto> { + return localVarFp.search(options).then((request) => request(axios, basePath)); }, }; }; @@ -6959,24 +6875,12 @@ export class SearchApi extends BaseAPI { /** * - * @param {string} [query] - * @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type] - * @param {boolean} [isFavorite] - * @param {string} [exifInfoCity] - * @param {string} [exifInfoState] - * @param {string} [exifInfoCountry] - * @param {string} [exifInfoMake] - * @param {string} [exifInfoModel] - * @param {Array<string>} [smartInfoObjects] - * @param {Array<string>} [smartInfoTags] - * @param {boolean} [recent] - * @param {boolean} [motion] * @param {*} [options] Override http request option. * @throws {RequiredError} * @memberof SearchApi */ - public search(query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options?: AxiosRequestConfig) { - return SearchApiFp(this.configuration).search(query, type, isFavorite, exifInfoCity, exifInfoState, exifInfoCountry, exifInfoMake, exifInfoModel, smartInfoObjects, smartInfoTags, recent, motion, options).then((request) => request(this.axios, this.basePath)); + public search(options?: AxiosRequestConfig) { + return SearchApiFp(this.configuration).search(options).then((request) => request(this.axios, this.basePath)); } } diff --git a/web/src/lib/components/shared-components/search-bar/search-bar.svelte b/web/src/lib/components/shared-components/search-bar/search-bar.svelte index d6518daed2..fb9b53c425 100644 --- a/web/src/lib/components/shared-components/search-bar/search-bar.svelte +++ b/web/src/lib/components/shared-components/search-bar/search-bar.svelte @@ -15,7 +15,8 @@ function onSearch() { const params = new URLSearchParams({ - q: value + q: value, + clip: 'true' }); goto(`${AppRoute.SEARCH}?${params}`, { replaceState: replaceHistoryState }); diff --git a/web/src/routes/(user)/search/+page.server.ts b/web/src/routes/(user)/search/+page.server.ts index fdc4824d05..71122f4582 100644 --- a/web/src/routes/(user)/search/+page.server.ts +++ b/web/src/routes/(user)/search/+page.server.ts @@ -7,22 +7,9 @@ export const load = (async ({ locals, parent, url }) => { throw redirect(302, '/auth/login'); } - const term = url.searchParams.get('q') || undefined; - const { data: results } = await locals.api.searchApi.search( - term, - undefined, - undefined, - undefined, - undefined, - undefined, - undefined, - undefined, - undefined, - undefined, - undefined, - undefined, - { params: url.searchParams } - ); + const term = url.searchParams.get('q') || url.searchParams.get('query') || undefined; + + const { data: results } = await locals.api.searchApi.search({ params: url.searchParams }); return { user,