1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-28 06:32:44 +01:00

feat(server): CLIP search integration ()

This commit is contained in:
Alex 2023-03-18 08:44:42 -05:00 committed by GitHub
parent 0d436db3ea
commit f56eaae019
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 666 additions and 651 deletions

View file

@ -1,43 +1,58 @@
import os import os
from flask import Flask, request from flask import Flask, request
from transformers import pipeline from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
from PIL import Image
is_dev = os.getenv('NODE_ENV') == 'development'
server_port = os.getenv('MACHINE_LEARNING_PORT', 3003)
server_host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0')
classification_model = os.getenv('MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny')
clip_image_model = os.getenv('MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
clip_text_model = os.getenv('MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')
_model_cache = {}
def _get_model(model, task=None):
global _model_cache
key = '|'.join([model, str(task)])
if key not in _model_cache:
if task:
_model_cache[key] = pipeline(model=model, task=task)
else:
_model_cache[key] = SentenceTransformer(model)
return _model_cache[key]
server = Flask(__name__) server = Flask(__name__)
classifier = pipeline(
task="image-classification",
model="microsoft/resnet-50"
)
detector = pipeline(
task="object-detection",
model="hustvl/yolos-tiny"
)
# Environment resolver
is_dev = os.getenv('NODE_ENV') == 'development'
server_port = os.getenv('MACHINE_LEARNING_PORT') or 3003
@server.route("/ping") @server.route("/ping")
def ping(): def ping():
return "pong" return "pong"
@server.route("/object-detection/detect-object", methods=['POST']) @server.route("/object-detection/detect-object", methods=['POST'])
def object_detection(): def object_detection():
model = _get_model(object_model, 'object-detection')
assetPath = request.json['thumbnailPath'] assetPath = request.json['thumbnailPath']
return run_engine(detector, assetPath), 201 return run_engine(model, assetPath), 200
@server.route("/image-classifier/tag-image", methods=['POST']) @server.route("/image-classifier/tag-image", methods=['POST'])
def image_classification(): def image_classification():
model = _get_model(classification_model, 'image-classification')
assetPath = request.json['thumbnailPath'] assetPath = request.json['thumbnailPath']
return run_engine(classifier, assetPath), 201 return run_engine(model, assetPath), 200
@server.route("/sentence-transformer/encode-image", methods=['POST'])
def clip_encode_image():
model = _get_model(clip_image_model)
assetPath = request.json['thumbnailPath']
return model.encode(Image.open(assetPath)).tolist(), 200
@server.route("/sentence-transformer/encode-text", methods=['POST'])
def clip_encode_text():
model = _get_model(clip_text_model)
text = request.json['text']
return model.encode(text).tolist(), 200
def run_engine(engine, path): def run_engine(engine, path):
result = [] result = []
@ -55,4 +70,4 @@ def run_engine(engine, path):
if __name__ == "__main__": if __name__ == "__main__":
server.run(debug=is_dev, host='0.0.0.0', port=server_port) server.run(debug=is_dev, host=server_host, port=server_port)

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -163,7 +163,7 @@ describe('Album service', () => {
expect(result.id).toEqual(albumEntity.id); expect(result.id).toEqual(albumEntity.id);
expect(result.albumName).toEqual(albumEntity.albumName); expect(result.albumName).toEqual(albumEntity.albumName);
expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: albumEntity } }); expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [albumEntity.id] } });
}); });
it('gets list of albums for auth user', async () => { it('gets list of albums for auth user', async () => {
@ -316,7 +316,7 @@ describe('Album service', () => {
albumName: updatedAlbumName, albumName: updatedAlbumName,
albumThumbnailAssetId: updatedAlbumThumbnailAssetId, albumThumbnailAssetId: updatedAlbumThumbnailAssetId,
}); });
expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: updatedAlbum } }); expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [updatedAlbum.id] } });
}); });
it('prevents updating a not owned album (shared with auth user)', async () => { it('prevents updating a not owned album (shared with auth user)', async () => {

View file

@ -59,7 +59,7 @@ export class AlbumService {
async create(authUser: AuthUserDto, createAlbumDto: CreateAlbumDto): Promise<AlbumResponseDto> { async create(authUser: AuthUserDto, createAlbumDto: CreateAlbumDto): Promise<AlbumResponseDto> {
const albumEntity = await this.albumRepository.create(authUser.id, createAlbumDto); const albumEntity = await this.albumRepository.create(authUser.id, createAlbumDto);
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: albumEntity } }); await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [albumEntity.id] } });
return mapAlbum(albumEntity); return mapAlbum(albumEntity);
} }
@ -107,7 +107,7 @@ export class AlbumService {
} }
await this.albumRepository.delete(album); await this.albumRepository.delete(album);
await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ALBUM, data: { id: albumId } }); await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ALBUM, data: { ids: [albumId] } });
} }
async removeUserFromAlbum(authUser: AuthUserDto, albumId: string, userId: string | 'me'): Promise<void> { async removeUserFromAlbum(authUser: AuthUserDto, albumId: string, userId: string | 'me'): Promise<void> {
@ -171,7 +171,7 @@ export class AlbumService {
const updatedAlbum = await this.albumRepository.updateAlbum(album, updateAlbumDto); const updatedAlbum = await this.albumRepository.updateAlbum(album, updateAlbumDto);
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: updatedAlbum } }); await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [updatedAlbum.id] } });
return mapAlbum(updatedAlbum); return mapAlbum(updatedAlbum);
} }

View file

@ -455,8 +455,8 @@ describe('AssetService', () => {
]); ]);
expect(jobMock.queue.mock.calls).toEqual([ expect(jobMock.queue.mock.calls).toEqual([
[{ name: JobName.SEARCH_REMOVE_ASSET, data: { id: 'asset1' } }], [{ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: ['asset1'] } }],
[{ name: JobName.SEARCH_REMOVE_ASSET, data: { id: 'asset2' } }], [{ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: ['asset2'] } }],
[ [
{ {
name: JobName.DELETE_FILES, name: JobName.DELETE_FILES,

View file

@ -170,7 +170,7 @@ export class AssetService {
const updatedAsset = await this._assetRepository.update(authUser.id, asset, dto); const updatedAsset = await this._assetRepository.update(authUser.id, asset, dto);
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { asset: updatedAsset } }); await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [assetId] } });
return mapAsset(updatedAsset); return mapAsset(updatedAsset);
} }
@ -251,8 +251,8 @@ export class AssetService {
res.header('Cache-Control', 'none'); res.header('Cache-Control', 'none');
Logger.error(`Cannot create read stream for asset ${asset.id}`, 'getAssetThumbnail'); Logger.error(`Cannot create read stream for asset ${asset.id}`, 'getAssetThumbnail');
throw new InternalServerErrorException( throw new InternalServerErrorException(
e,
`Cannot read thumbnail file for asset ${asset.id} - contact your administrator`, `Cannot read thumbnail file for asset ${asset.id} - contact your administrator`,
{ cause: e as Error },
); );
} }
} }
@ -427,7 +427,7 @@ export class AssetService {
try { try {
await this._assetRepository.remove(asset); await this._assetRepository.remove(asset);
await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ASSET, data: { id } }); await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: [id] } });
result.push({ id, status: DeleteAssetStatusEnum.SUCCESS }); result.push({ id, status: DeleteAssetStatusEnum.SUCCESS });
deleteQueue.push(asset.originalPath, asset.webpPath, asset.resizePath, asset.encodedVideoPath); deleteQueue.push(asset.originalPath, asset.webpPath, asset.resizePath, asset.encodedVideoPath);

View file

@ -70,6 +70,7 @@ export class JobService {
for (const asset of assets) { for (const asset of assets) {
await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } }); await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } });
await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } }); await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } });
await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } });
} }
return assets.length; return assets.length;
} }

View file

@ -20,7 +20,7 @@ export class SearchController {
@Get() @Get()
async search( async search(
@GetAuthUser() authUser: AuthUserDto, @GetAuthUser() authUser: AuthUserDto,
@Query(new ValidationPipe({ transform: true })) dto: SearchDto, @Query(new ValidationPipe({ transform: true })) dto: SearchDto | any,
): Promise<SearchResponseDto> { ): Promise<SearchResponseDto> {
return this.searchService.search(authUser, dto); return this.searchService.search(authUser, dto);
} }

View file

@ -1,10 +1,9 @@
import { import {
AssetService, AssetService,
IAlbumJob,
IAssetJob, IAssetJob,
IAssetUploadedJob, IAssetUploadedJob,
IBulkEntityJob,
IDeleteFilesJob, IDeleteFilesJob,
IDeleteJob,
IUserDeletionJob, IUserDeletionJob,
JobName, JobName,
MediaService, MediaService,
@ -53,15 +52,20 @@ export class BackgroundTaskProcessor {
export class MachineLearningProcessor { export class MachineLearningProcessor {
constructor(private smartInfoService: SmartInfoService) {} constructor(private smartInfoService: SmartInfoService) {}
@Process({ name: JobName.IMAGE_TAGGING, concurrency: 2 }) @Process({ name: JobName.IMAGE_TAGGING, concurrency: 1 })
async onTagImage(job: Job<IAssetJob>) { async onTagImage(job: Job<IAssetJob>) {
await this.smartInfoService.handleTagImage(job.data); await this.smartInfoService.handleTagImage(job.data);
} }
@Process({ name: JobName.OBJECT_DETECTION, concurrency: 2 }) @Process({ name: JobName.OBJECT_DETECTION, concurrency: 1 })
async onDetectObject(job: Job<IAssetJob>) { async onDetectObject(job: Job<IAssetJob>) {
await this.smartInfoService.handleDetectObjects(job.data); await this.smartInfoService.handleDetectObjects(job.data);
} }
@Process({ name: JobName.ENCODE_CLIP, concurrency: 1 })
async onEncodeClip(job: Job<IAssetJob>) {
await this.smartInfoService.handleEncodeClip(job.data);
}
} }
@Processor(QueueName.SEARCH) @Processor(QueueName.SEARCH)
@ -79,23 +83,23 @@ export class SearchIndexProcessor {
} }
@Process(JobName.SEARCH_INDEX_ALBUM) @Process(JobName.SEARCH_INDEX_ALBUM)
async onIndexAlbum(job: Job<IAlbumJob>) { onIndexAlbum(job: Job<IBulkEntityJob>) {
await this.searchService.handleIndexAlbum(job.data); this.searchService.handleIndexAlbum(job.data);
} }
@Process(JobName.SEARCH_INDEX_ASSET) @Process(JobName.SEARCH_INDEX_ASSET)
async onIndexAsset(job: Job<IAssetJob>) { onIndexAsset(job: Job<IBulkEntityJob>) {
await this.searchService.handleIndexAsset(job.data); this.searchService.handleIndexAsset(job.data);
} }
@Process(JobName.SEARCH_REMOVE_ALBUM) @Process(JobName.SEARCH_REMOVE_ALBUM)
async onRemoveAlbum(job: Job<IDeleteJob>) { onRemoveAlbum(job: Job<IBulkEntityJob>) {
await this.searchService.handleRemoveAlbum(job.data); this.searchService.handleRemoveAlbum(job.data);
} }
@Process(JobName.SEARCH_REMOVE_ASSET) @Process(JobName.SEARCH_REMOVE_ASSET)
async onRemoveAsset(job: Job<IDeleteJob>) { onRemoveAsset(job: Job<IBulkEntityJob>) {
await this.searchService.handleRemoveAsset(job.data); this.searchService.handleRemoveAsset(job.data);
} }
} }

View file

@ -548,116 +548,7 @@
"get": { "get": {
"operationId": "search", "operationId": "search",
"description": "", "description": "",
"parameters": [ "parameters": [],
{
"name": "query",
"required": false,
"in": "query",
"schema": {
"type": "string"
}
},
{
"name": "type",
"required": false,
"in": "query",
"schema": {
"enum": [
"IMAGE",
"VIDEO",
"AUDIO",
"OTHER"
],
"type": "string"
}
},
{
"name": "isFavorite",
"required": false,
"in": "query",
"schema": {
"type": "boolean"
}
},
{
"name": "exifInfo.city",
"required": false,
"in": "query",
"schema": {
"type": "string"
}
},
{
"name": "exifInfo.state",
"required": false,
"in": "query",
"schema": {
"type": "string"
}
},
{
"name": "exifInfo.country",
"required": false,
"in": "query",
"schema": {
"type": "string"
}
},
{
"name": "exifInfo.make",
"required": false,
"in": "query",
"schema": {
"type": "string"
}
},
{
"name": "exifInfo.model",
"required": false,
"in": "query",
"schema": {
"type": "string"
}
},
{
"name": "smartInfo.objects",
"required": false,
"in": "query",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
},
{
"name": "smartInfo.tags",
"required": false,
"in": "query",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
},
{
"name": "recent",
"required": false,
"in": "query",
"schema": {
"type": "boolean"
}
},
{
"name": "motion",
"required": false,
"in": "query",
"schema": {
"type": "boolean"
}
}
],
"responses": { "responses": {
"200": { "200": {
"description": "", "description": "",

View file

@ -3,6 +3,7 @@ import { AlbumEntity } from '@app/infra/db/entities';
export const IAlbumRepository = 'IAlbumRepository'; export const IAlbumRepository = 'IAlbumRepository';
export interface IAlbumRepository { export interface IAlbumRepository {
getByIds(ids: string[]): Promise<AlbumEntity[]>;
deleteAll(userId: string): Promise<void>; deleteAll(userId: string): Promise<void>;
getAll(): Promise<AlbumEntity[]>; getAll(): Promise<AlbumEntity[]>;
save(album: Partial<AlbumEntity>): Promise<AlbumEntity>; save(album: Partial<AlbumEntity>): Promise<AlbumEntity>;

View file

@ -11,7 +11,10 @@ export class AssetCore {
async save(asset: Partial<AssetEntity>) { async save(asset: Partial<AssetEntity>) {
const _asset = await this.assetRepository.save(asset); const _asset = await this.assetRepository.save(asset);
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { asset: _asset } }); await this.jobRepository.queue({
name: JobName.SEARCH_INDEX_ASSET,
data: { ids: [_asset.id] },
});
return _asset; return _asset;
} }

View file

@ -7,6 +7,7 @@ export interface AssetSearchOptions {
export const IAssetRepository = 'IAssetRepository'; export const IAssetRepository = 'IAssetRepository';
export interface IAssetRepository { export interface IAssetRepository {
getByIds(ids: string[]): Promise<AssetEntity[]>;
deleteAll(ownerId: string): Promise<void>; deleteAll(ownerId: string): Promise<void>;
getAll(options?: AssetSearchOptions): Promise<AssetEntity[]>; getAll(options?: AssetSearchOptions): Promise<AssetEntity[]>;
save(asset: Partial<AssetEntity>): Promise<AssetEntity>; save(asset: Partial<AssetEntity>): Promise<AssetEntity>;

View file

@ -54,7 +54,7 @@ describe(AssetService.name, () => {
expect(assetMock.save).toHaveBeenCalledWith(assetEntityStub.image); expect(assetMock.save).toHaveBeenCalledWith(assetEntityStub.image);
expect(jobMock.queue).toHaveBeenCalledWith({ expect(jobMock.queue).toHaveBeenCalledWith({
name: JobName.SEARCH_INDEX_ASSET, name: JobName.SEARCH_INDEX_ASSET,
data: { asset: assetEntityStub.image }, data: { ids: [assetEntityStub.image.id] },
}); });
}); });
}); });

View file

@ -29,4 +29,5 @@ export enum JobName {
SEARCH_INDEX_ALBUM = 'search-index-album', SEARCH_INDEX_ALBUM = 'search-index-album',
SEARCH_REMOVE_ALBUM = 'search-remove-album', SEARCH_REMOVE_ALBUM = 'search-remove-album',
SEARCH_REMOVE_ASSET = 'search-remove-asset', SEARCH_REMOVE_ASSET = 'search-remove-asset',
ENCODE_CLIP = 'clip-encode',
} }

View file

@ -8,15 +8,15 @@ export interface IAssetJob {
asset: AssetEntity; asset: AssetEntity;
} }
export interface IBulkEntityJob {
ids: string[];
}
export interface IAssetUploadedJob { export interface IAssetUploadedJob {
asset: AssetEntity; asset: AssetEntity;
fileName: string; fileName: string;
} }
export interface IDeleteJob {
id: string;
}
export interface IDeleteFilesJob { export interface IDeleteFilesJob {
files: Array<string | null | undefined>; files: Array<string | null | undefined>;
} }

View file

@ -1,10 +1,9 @@
import { JobName, QueueName } from './job.constants'; import { JobName, QueueName } from './job.constants';
import { import {
IAlbumJob,
IAssetJob, IAssetJob,
IAssetUploadedJob, IAssetUploadedJob,
IBulkEntityJob,
IDeleteFilesJob, IDeleteFilesJob,
IDeleteJob,
IReverseGeocodingJob, IReverseGeocodingJob,
IUserDeletionJob, IUserDeletionJob,
} from './job.interface'; } from './job.interface';
@ -31,13 +30,14 @@ export type JobItem =
| { name: JobName.EXTRACT_VIDEO_METADATA; data: IAssetUploadedJob } | { name: JobName.EXTRACT_VIDEO_METADATA; data: IAssetUploadedJob }
| { name: JobName.OBJECT_DETECTION; data: IAssetJob } | { name: JobName.OBJECT_DETECTION; data: IAssetJob }
| { name: JobName.IMAGE_TAGGING; data: IAssetJob } | { name: JobName.IMAGE_TAGGING; data: IAssetJob }
| { name: JobName.ENCODE_CLIP; data: IAssetJob }
| { name: JobName.DELETE_FILES; data: IDeleteFilesJob } | { name: JobName.DELETE_FILES; data: IDeleteFilesJob }
| { name: JobName.SEARCH_INDEX_ASSETS } | { name: JobName.SEARCH_INDEX_ASSETS }
| { name: JobName.SEARCH_INDEX_ASSET; data: IAssetJob } | { name: JobName.SEARCH_INDEX_ASSET; data: IBulkEntityJob }
| { name: JobName.SEARCH_INDEX_ALBUMS } | { name: JobName.SEARCH_INDEX_ALBUMS }
| { name: JobName.SEARCH_INDEX_ALBUM; data: IAlbumJob } | { name: JobName.SEARCH_INDEX_ALBUM; data: IBulkEntityJob }
| { name: JobName.SEARCH_REMOVE_ASSET; data: IDeleteJob } | { name: JobName.SEARCH_REMOVE_ASSET; data: IBulkEntityJob }
| { name: JobName.SEARCH_REMOVE_ALBUM; data: IDeleteJob }; | { name: JobName.SEARCH_REMOVE_ALBUM; data: IBulkEntityJob };
export const IJobRepository = 'IJobRepository'; export const IJobRepository = 'IJobRepository';

View file

@ -54,6 +54,7 @@ export class MediaService {
await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } }); await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } });
await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } }); await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } });
await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } }); await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } });
await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } });
this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset)); this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset));
} }
@ -72,6 +73,7 @@ export class MediaService {
await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } }); await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } });
await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } }); await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } });
await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } }); await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } });
await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } });
this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset)); this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset));
} catch (error: any) { } catch (error: any) {

View file

@ -4,11 +4,21 @@ import { IsArray, IsBoolean, IsEnum, IsNotEmpty, IsOptional, IsString } from 'cl
import { toBoolean } from '../../../../../apps/immich/src/utils/transform.util'; import { toBoolean } from '../../../../../apps/immich/src/utils/transform.util';
export class SearchDto { export class SearchDto {
@IsString()
@IsNotEmpty()
@IsOptional()
q?: string;
@IsString() @IsString()
@IsNotEmpty() @IsNotEmpty()
@IsOptional() @IsOptional()
query?: string; query?: string;
@IsBoolean()
@IsOptional()
@Transform(toBoolean)
clip?: boolean;
@IsEnum(AssetType) @IsEnum(AssetType)
@IsOptional() @IsOptional()
type?: AssetType; type?: AssetType;

View file

@ -5,6 +5,11 @@ export enum SearchCollection {
ALBUMS = 'albums', ALBUMS = 'albums',
} }
export enum SearchStrategy {
CLIP = 'CLIP',
TEXT = 'TEXT',
}
export interface SearchFilter { export interface SearchFilter {
id?: string; id?: string;
userId: string; userId: string;
@ -19,6 +24,7 @@ export interface SearchFilter {
tags?: string[]; tags?: string[];
recent?: boolean; recent?: boolean;
motion?: boolean; motion?: boolean;
debug?: boolean;
} }
export interface SearchResult<T> { export interface SearchResult<T> {
@ -57,16 +63,15 @@ export interface ISearchRepository {
setup(): Promise<void>; setup(): Promise<void>;
checkMigrationStatus(): Promise<SearchCollectionIndexStatus>; checkMigrationStatus(): Promise<SearchCollectionIndexStatus>;
index(collection: SearchCollection.ASSETS, item: AssetEntity): Promise<void>; importAlbums(items: AlbumEntity[], done: boolean): Promise<void>;
index(collection: SearchCollection.ALBUMS, item: AlbumEntity): Promise<void>; importAssets(items: AssetEntity[], done: boolean): Promise<void>;
delete(collection: SearchCollection, id: string): Promise<void>; deleteAlbums(ids: string[]): Promise<void>;
deleteAssets(ids: string[]): Promise<void>;
import(collection: SearchCollection.ASSETS, items: AssetEntity[], done: boolean): Promise<void>; searchAlbums(query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>;
import(collection: SearchCollection.ALBUMS, items: AlbumEntity[], done: boolean): Promise<void>; searchAssets(query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>;
vectorSearch(query: number[], filters: SearchFilter): Promise<SearchResult<AssetEntity>>;
search(collection: SearchCollection.ASSETS, query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>;
search(collection: SearchCollection.ALBUMS, query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>;
explore(userId: string): Promise<SearchExploreItem<AssetEntity>[]>; explore(userId: string): Promise<SearchExploreItem<AssetEntity>[]>;
} }

View file

@ -4,25 +4,32 @@ import { plainToInstance } from 'class-transformer';
import { import {
albumStub, albumStub,
assetEntityStub, assetEntityStub,
asyncTick,
authStub, authStub,
newAlbumRepositoryMock, newAlbumRepositoryMock,
newAssetRepositoryMock, newAssetRepositoryMock,
newJobRepositoryMock, newJobRepositoryMock,
newMachineLearningRepositoryMock,
newSearchRepositoryMock, newSearchRepositoryMock,
searchStub,
} from '../../test'; } from '../../test';
import { IAlbumRepository } from '../album/album.repository'; import { IAlbumRepository } from '../album/album.repository';
import { IAssetRepository } from '../asset/asset.repository'; import { IAssetRepository } from '../asset/asset.repository';
import { JobName } from '../job'; import { JobName } from '../job';
import { IJobRepository } from '../job/job.repository'; import { IJobRepository } from '../job/job.repository';
import { IMachineLearningRepository } from '../smart-info';
import { SearchDto } from './dto'; import { SearchDto } from './dto';
import { ISearchRepository } from './search.repository'; import { ISearchRepository } from './search.repository';
import { SearchService } from './search.service'; import { SearchService } from './search.service';
jest.useFakeTimers();
describe(SearchService.name, () => { describe(SearchService.name, () => {
let sut: SearchService; let sut: SearchService;
let albumMock: jest.Mocked<IAlbumRepository>; let albumMock: jest.Mocked<IAlbumRepository>;
let assetMock: jest.Mocked<IAssetRepository>; let assetMock: jest.Mocked<IAssetRepository>;
let jobMock: jest.Mocked<IJobRepository>; let jobMock: jest.Mocked<IJobRepository>;
let machineMock: jest.Mocked<IMachineLearningRepository>;
let searchMock: jest.Mocked<ISearchRepository>; let searchMock: jest.Mocked<ISearchRepository>;
let configMock: jest.Mocked<ConfigService>; let configMock: jest.Mocked<ConfigService>;
@ -30,10 +37,15 @@ describe(SearchService.name, () => {
albumMock = newAlbumRepositoryMock(); albumMock = newAlbumRepositoryMock();
assetMock = newAssetRepositoryMock(); assetMock = newAssetRepositoryMock();
jobMock = newJobRepositoryMock(); jobMock = newJobRepositoryMock();
machineMock = newMachineLearningRepositoryMock();
searchMock = newSearchRepositoryMock(); searchMock = newSearchRepositoryMock();
configMock = { get: jest.fn() } as unknown as jest.Mocked<ConfigService>; configMock = { get: jest.fn() } as unknown as jest.Mocked<ConfigService>;
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
});
afterEach(() => {
sut.teardown();
}); });
it('should work', () => { it('should work', () => {
@ -69,7 +81,7 @@ describe(SearchService.name, () => {
it('should be disabled via an env variable', () => { it('should be disabled via an env variable', () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
expect(sut.isEnabled()).toBe(false); expect(sut.isEnabled()).toBe(false);
}); });
@ -82,7 +94,7 @@ describe(SearchService.name, () => {
it('should return the config when search is disabled', () => { it('should return the config when search is disabled', () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
expect(sut.getConfig()).toEqual({ enabled: false }); expect(sut.getConfig()).toEqual({ enabled: false });
}); });
@ -91,13 +103,15 @@ describe(SearchService.name, () => {
describe(`bootstrap`, () => { describe(`bootstrap`, () => {
it('should skip when search is disabled', async () => { it('should skip when search is disabled', async () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
await sut.bootstrap(); await sut.bootstrap();
expect(searchMock.setup).not.toHaveBeenCalled(); expect(searchMock.setup).not.toHaveBeenCalled();
expect(searchMock.checkMigrationStatus).not.toHaveBeenCalled(); expect(searchMock.checkMigrationStatus).not.toHaveBeenCalled();
expect(jobMock.queue).not.toHaveBeenCalled(); expect(jobMock.queue).not.toHaveBeenCalled();
sut.teardown();
}); });
it('should skip schema migration if not needed', async () => { it('should skip schema migration if not needed', async () => {
@ -123,21 +137,18 @@ describe(SearchService.name, () => {
describe('search', () => { describe('search', () => {
it('should throw an error is search is disabled', async () => { it('should throw an error is search is disabled', async () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
await expect(sut.search(authStub.admin, {})).rejects.toBeInstanceOf(BadRequestException); await expect(sut.search(authStub.admin, {})).rejects.toBeInstanceOf(BadRequestException);
expect(searchMock.search).not.toHaveBeenCalled(); expect(searchMock.searchAlbums).not.toHaveBeenCalled();
expect(searchMock.searchAssets).not.toHaveBeenCalled();
}); });
it('should search assets and albums', async () => { it('should search assets and albums', async () => {
searchMock.search.mockResolvedValue({ searchMock.searchAssets.mockResolvedValue(searchStub.emptyResults);
total: 0, searchMock.searchAlbums.mockResolvedValue(searchStub.emptyResults);
count: 0, searchMock.vectorSearch.mockResolvedValue(searchStub.emptyResults);
page: 1,
items: [],
facets: [],
});
await expect(sut.search(authStub.admin, {})).resolves.toEqual({ await expect(sut.search(authStub.admin, {})).resolves.toEqual({
albums: { albums: {
@ -156,162 +167,158 @@ describe(SearchService.name, () => {
}, },
}); });
expect(searchMock.search.mock.calls).toEqual([ // expect(searchMock.searchAssets).toHaveBeenCalledWith('*', { userId: authStub.admin.id });
['assets', '*', { userId: authStub.admin.id }], expect(searchMock.searchAlbums).toHaveBeenCalledWith('*', { userId: authStub.admin.id });
['albums', '*', { userId: authStub.admin.id }],
]);
}); });
}); });
describe('handleIndexAssets', () => { describe('handleIndexAssets', () => {
it('should skip if search is disabled', async () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
await sut.handleIndexAssets();
expect(searchMock.import).not.toHaveBeenCalled();
});
it('should index all the assets', async () => { it('should index all the assets', async () => {
assetMock.getAll.mockResolvedValue([]); assetMock.getAll.mockResolvedValue([assetEntityStub.image]);
await sut.handleIndexAssets(); await sut.handleIndexAssets();
expect(searchMock.import).toHaveBeenCalledWith('assets', [], true); expect(searchMock.importAssets).toHaveBeenCalledWith([assetEntityStub.image], true);
}); });
it('should log an error', async () => { it('should log an error', async () => {
assetMock.getAll.mockResolvedValue([]); assetMock.getAll.mockResolvedValue([assetEntityStub.image]);
searchMock.import.mockRejectedValue(new Error('import failed')); searchMock.importAssets.mockRejectedValue(new Error('import failed'));
await sut.handleIndexAssets(); await sut.handleIndexAssets();
expect(searchMock.importAssets).toHaveBeenCalled();
});
it('should skip if search is disabled', async () => {
configMock.get.mockReturnValue('false');
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
await sut.handleIndexAssets();
expect(searchMock.importAssets).not.toHaveBeenCalled();
expect(searchMock.importAlbums).not.toHaveBeenCalled();
}); });
}); });
describe('handleIndexAsset', () => { describe('handleIndexAsset', () => {
it('should skip if search is disabled', async () => { it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleIndexAsset({ ids: [assetEntityStub.image.id] });
await sut.handleIndexAsset({ asset: assetEntityStub.image });
expect(searchMock.index).not.toHaveBeenCalled();
}); });
it('should index the asset', async () => { it('should index the asset', () => {
await sut.handleIndexAsset({ asset: assetEntityStub.image }); sut.handleIndexAsset({ ids: [assetEntityStub.image.id] });
expect(searchMock.index).toHaveBeenCalledWith('assets', assetEntityStub.image);
});
it('should log an error', async () => {
searchMock.index.mockRejectedValue(new Error('index failed'));
await sut.handleIndexAsset({ asset: assetEntityStub.image });
expect(searchMock.index).toHaveBeenCalled();
}); });
}); });
describe('handleIndexAlbums', () => { describe('handleIndexAlbums', () => {
it('should skip if search is disabled', async () => { it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleIndexAlbums();
await sut.handleIndexAlbums();
expect(searchMock.import).not.toHaveBeenCalled();
}); });
it('should index all the albums', async () => { it('should index all the albums', async () => {
albumMock.getAll.mockResolvedValue([]); albumMock.getAll.mockResolvedValue([albumStub.empty]);
await sut.handleIndexAlbums(); await sut.handleIndexAlbums();
expect(searchMock.import).toHaveBeenCalledWith('albums', [], true); expect(searchMock.importAlbums).toHaveBeenCalledWith([albumStub.empty], true);
}); });
it('should log an error', async () => { it('should log an error', async () => {
albumMock.getAll.mockResolvedValue([]); albumMock.getAll.mockResolvedValue([albumStub.empty]);
searchMock.import.mockRejectedValue(new Error('import failed')); searchMock.importAlbums.mockRejectedValue(new Error('import failed'));
await sut.handleIndexAlbums(); await sut.handleIndexAlbums();
expect(searchMock.importAlbums).toHaveBeenCalled();
}); });
}); });
describe('handleIndexAlbum', () => { describe('handleIndexAlbum', () => {
it('should skip if search is disabled', async () => { it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleIndexAlbum({ ids: [albumStub.empty.id] });
await sut.handleIndexAlbum({ album: albumStub.empty });
expect(searchMock.index).not.toHaveBeenCalled();
}); });
it('should index the album', async () => { it('should index the album', () => {
await sut.handleIndexAlbum({ album: albumStub.empty }); sut.handleIndexAlbum({ ids: [albumStub.empty.id] });
expect(searchMock.index).toHaveBeenCalledWith('albums', albumStub.empty);
});
it('should log an error', async () => {
searchMock.index.mockRejectedValue(new Error('index failed'));
await sut.handleIndexAlbum({ album: albumStub.empty });
expect(searchMock.index).toHaveBeenCalled();
}); });
}); });
describe('handleRemoveAlbum', () => { describe('handleRemoveAlbum', () => {
it('should skip if search is disabled', async () => { it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleRemoveAlbum({ ids: ['album1'] });
await sut.handleRemoveAlbum({ id: 'album1' });
expect(searchMock.delete).not.toHaveBeenCalled();
}); });
it('should remove the album', async () => { it('should remove the album', () => {
await sut.handleRemoveAlbum({ id: 'album1' }); sut.handleRemoveAlbum({ ids: ['album1'] });
expect(searchMock.delete).toHaveBeenCalledWith('albums', 'album1');
});
it('should log an error', async () => {
searchMock.delete.mockRejectedValue(new Error('remove failed'));
await sut.handleRemoveAlbum({ id: 'album1' });
expect(searchMock.delete).toHaveBeenCalled();
}); });
}); });
describe('handleRemoveAsset', () => { describe('handleRemoveAsset', () => {
it('should skip if search is disabled', async () => { it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false'); configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock); const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleRemoveAsset({ ids: ['asset1'] });
await sut.handleRemoveAsset({ id: 'asset1`' });
expect(searchMock.delete).not.toHaveBeenCalled();
}); });
it('should remove the asset', async () => { it('should remove the asset', () => {
await sut.handleRemoveAsset({ id: 'asset1' }); sut.handleRemoveAsset({ ids: ['asset1'] });
});
});
expect(searchMock.delete).toHaveBeenCalledWith('assets', 'asset1'); describe('flush', () => {
it('should flush queued album updates', async () => {
albumMock.getByIds.mockResolvedValue([albumStub.empty]);
sut.handleIndexAlbum({ ids: ['album1'] });
jest.runOnlyPendingTimers();
await asyncTick(4);
expect(albumMock.getByIds).toHaveBeenCalledWith(['album1']);
expect(searchMock.importAlbums).toHaveBeenCalledWith([albumStub.empty], false);
}); });
it('should log an error', async () => { it('should flush queued album deletes', async () => {
searchMock.delete.mockRejectedValue(new Error('remove failed')); sut.handleRemoveAlbum({ ids: ['album1'] });
await sut.handleRemoveAsset({ id: 'asset1' }); jest.runOnlyPendingTimers();
expect(searchMock.delete).toHaveBeenCalled(); await asyncTick(4);
expect(searchMock.deleteAlbums).toHaveBeenCalledWith(['album1']);
});
it('should flush queued asset updates', async () => {
assetMock.getByIds.mockResolvedValue([assetEntityStub.image]);
sut.handleIndexAsset({ ids: ['asset1'] });
jest.runOnlyPendingTimers();
await asyncTick(4);
expect(assetMock.getByIds).toHaveBeenCalledWith(['asset1']);
expect(searchMock.importAssets).toHaveBeenCalledWith([assetEntityStub.image], false);
});
it('should flush queued asset deletes', async () => {
sut.handleRemoveAsset({ ids: ['asset1'] });
jest.runOnlyPendingTimers();
await asyncTick(4);
expect(searchMock.deleteAssets).toHaveBeenCalledWith(['asset1']);
}); });
}); });
}); });

View file

@ -1,27 +1,64 @@
import { AssetEntity } from '@app/infra/db/entities'; import { MACHINE_LEARNING_ENABLED } from '@app/common';
import { AlbumEntity, AssetEntity } from '@app/infra/db/entities';
import { BadRequestException, Inject, Injectable, Logger } from '@nestjs/common'; import { BadRequestException, Inject, Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config'; import { ConfigService } from '@nestjs/config';
import { mapAlbum } from '../album';
import { IAlbumRepository } from '../album/album.repository'; import { IAlbumRepository } from '../album/album.repository';
import { mapAsset } from '../asset';
import { IAssetRepository } from '../asset/asset.repository'; import { IAssetRepository } from '../asset/asset.repository';
import { AuthUserDto } from '../auth'; import { AuthUserDto } from '../auth';
import { IAlbumJob, IAssetJob, IDeleteJob, IJobRepository, JobName } from '../job'; import { IBulkEntityJob, IJobRepository, JobName } from '../job';
import { IMachineLearningRepository } from '../smart-info';
import { SearchDto } from './dto'; import { SearchDto } from './dto';
import { SearchConfigResponseDto, SearchResponseDto } from './response-dto'; import { SearchConfigResponseDto, SearchResponseDto } from './response-dto';
import { ISearchRepository, SearchCollection, SearchExploreItem } from './search.repository'; import {
ISearchRepository,
SearchCollection,
SearchExploreItem,
SearchResult,
SearchStrategy,
} from './search.repository';
/** Pending search-index changes for one collection, tracked as sets of entity ids. */
interface SyncQueue {
// ids whose entities should be (re)imported into the search index
upsert: Set<string>;
// ids whose documents should be removed from the search index
delete: Set<string>;
}
@Injectable() @Injectable()
export class SearchService { export class SearchService {
private logger = new Logger(SearchService.name); private logger = new Logger(SearchService.name);
private enabled: boolean; private enabled: boolean;
private timer: NodeJS.Timer | null = null;
private albumQueue: SyncQueue = {
upsert: new Set(),
delete: new Set(),
};
private assetQueue: SyncQueue = {
upsert: new Set(),
delete: new Set(),
};
constructor( constructor(
@Inject(IAlbumRepository) private albumRepository: IAlbumRepository, @Inject(IAlbumRepository) private albumRepository: IAlbumRepository,
@Inject(IAssetRepository) private assetRepository: IAssetRepository, @Inject(IAssetRepository) private assetRepository: IAssetRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository, @Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository,
@Inject(ISearchRepository) private searchRepository: ISearchRepository, @Inject(ISearchRepository) private searchRepository: ISearchRepository,
configService: ConfigService, configService: ConfigService,
) { ) {
this.enabled = configService.get('TYPESENSE_ENABLED') !== 'false'; this.enabled = configService.get('TYPESENSE_ENABLED') !== 'false';
if (this.enabled) {
this.timer = setInterval(() => this.flush(), 5_000);
}
}
/**
 * Stops the interval timer held in `this.timer`, if there is one.
 * Calling it again afterwards is a no-op because the handle is reset to null.
 */
teardown() {
  if (!this.timer) {
    return;
  }

  clearInterval(this.timer);
  this.timer = null;
}
isEnabled() { isEnabled() {
@ -61,103 +98,131 @@ export class SearchService {
async search(authUser: AuthUserDto, dto: SearchDto): Promise<SearchResponseDto> { async search(authUser: AuthUserDto, dto: SearchDto): Promise<SearchResponseDto> {
this.assertEnabled(); this.assertEnabled();
const query = dto.query || '*'; const query = dto.q || dto.query || '*';
const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
const filters = { userId: authUser.id, ...dto };
let assets: SearchResult<AssetEntity>;
switch (strategy) {
case SearchStrategy.TEXT:
assets = await this.searchRepository.searchAssets(query, filters);
break;
case SearchStrategy.CLIP:
default:
if (!MACHINE_LEARNING_ENABLED) {
throw new BadRequestException('Machine Learning is disabled');
}
const clip = await this.machineLearning.encodeText(query);
assets = await this.searchRepository.vectorSearch(clip, filters);
}
const albums = await this.searchRepository.searchAlbums(query, filters);
return { return {
assets: (await this.searchRepository.search(SearchCollection.ASSETS, query, { albums: { ...albums, items: albums.items.map(mapAlbum) },
userId: authUser.id, assets: { ...assets, items: assets.items.map(mapAsset) },
...dto,
})) as any,
albums: (await this.searchRepository.search(SearchCollection.ALBUMS, query, {
userId: authUser.id,
...dto,
})) as any,
}; };
} }
async handleIndexAssets() {
if (!this.enabled) {
return;
}
try {
this.logger.debug(`Running indexAssets`);
// TODO: do this in batches based on searchIndexVersion
const assets = await this.assetRepository.getAll({ isVisible: true });
this.logger.log(`Indexing ${assets.length} assets`);
await this.searchRepository.import(SearchCollection.ASSETS, assets, true);
this.logger.debug('Finished re-indexing all assets');
} catch (error: any) {
this.logger.error(`Unable to index all assets`, error?.stack);
}
}
async handleIndexAsset(data: IAssetJob) {
if (!this.enabled) {
return;
}
const { asset } = data;
if (!asset.isVisible) {
return;
}
try {
await this.searchRepository.index(SearchCollection.ASSETS, asset);
} catch (error: any) {
this.logger.error(`Unable to index asset: ${asset.id}`, error?.stack);
}
}
async handleIndexAlbums() { async handleIndexAlbums() {
if (!this.enabled) { if (!this.enabled) {
return; return;
} }
try { try {
const albums = await this.albumRepository.getAll(); const albums = this.patchAlbums(await this.albumRepository.getAll());
this.logger.log(`Indexing ${albums.length} albums`); this.logger.log(`Indexing ${albums.length} albums`);
await this.searchRepository.import(SearchCollection.ALBUMS, albums, true); await this.searchRepository.importAlbums(albums, true);
this.logger.debug('Finished re-indexing all albums');
} catch (error: any) { } catch (error: any) {
this.logger.error(`Unable to index all albums`, error?.stack); this.logger.error(`Unable to index all albums`, error?.stack);
} }
} }
async handleIndexAlbum(data: IAlbumJob) { async handleIndexAssets() {
if (!this.enabled) { if (!this.enabled) {
return; return;
} }
const { album } = data;
try { try {
await this.searchRepository.index(SearchCollection.ALBUMS, album); // TODO: do this in batches based on searchIndexVersion
const assets = this.patchAssets(await this.assetRepository.getAll({ isVisible: true }));
this.logger.log(`Indexing ${assets.length} assets`);
await this.searchRepository.importAssets(assets, true);
this.logger.debug('Finished re-indexing all assets');
} catch (error: any) { } catch (error: any) {
this.logger.error(`Unable to index album: ${album.id}`, error?.stack); this.logger.error(`Unable to index all assets`, error?.stack);
} }
} }
async handleRemoveAlbum(data: IDeleteJob) { handleIndexAlbum({ ids }: IBulkEntityJob) {
await this.handleRemove(SearchCollection.ALBUMS, data);
}
async handleRemoveAsset(data: IDeleteJob) {
await this.handleRemove(SearchCollection.ASSETS, data);
}
private async handleRemove(collection: SearchCollection, data: IDeleteJob) {
if (!this.enabled) { if (!this.enabled) {
return; return;
} }
const { id } = data; for (const id of ids) {
this.albumQueue.upsert.add(id);
}
}
try { handleIndexAsset({ ids }: IBulkEntityJob) {
await this.searchRepository.delete(collection, id); if (!this.enabled) {
} catch (error: any) { return;
this.logger.error(`Unable to remove ${collection}: ${id}`, error?.stack); }
for (const id of ids) {
this.assetQueue.upsert.add(id);
}
}
handleRemoveAlbum({ ids }: IBulkEntityJob) {
if (!this.enabled) {
return;
}
for (const id of ids) {
this.albumQueue.delete.add(id);
}
}
handleRemoveAsset({ ids }: IBulkEntityJob) {
if (!this.enabled) {
return;
}
for (const id of ids) {
this.assetQueue.delete.add(id);
}
}
/**
 * Pushes every queued album/asset change to the search repository.
 *
 * Each queue is snapshotted and cleared *before* the first `await` of its
 * stage, so ids that get queued while a flush is still in flight are kept for
 * the next run instead of being wiped by a trailing `clear()`.
 *
 * A failing stage is logged and its ids are put back on the queue (so they are
 * retried on the next flush, as before) instead of rejecting: the interval
 * callback that invokes this method does not handle rejections, and one
 * failing stage should not prevent the remaining stages from running.
 */
private async flush() {
  if (this.albumQueue.upsert.size > 0) {
    const ids = [...this.albumQueue.upsert];
    this.albumQueue.upsert.clear();
    try {
      const items = await this.idsToAlbums(ids);
      this.logger.debug(`Flushing ${items.length} album upserts`);
      await this.searchRepository.importAlbums(items, false);
    } catch (error: any) {
      this.logger.error(`Unable to flush album upserts`, error?.stack);
      ids.forEach((id) => this.albumQueue.upsert.add(id));
    }
  }

  if (this.albumQueue.delete.size > 0) {
    const ids = [...this.albumQueue.delete];
    this.albumQueue.delete.clear();
    try {
      this.logger.debug(`Flushing ${ids.length} album deletes`);
      await this.searchRepository.deleteAlbums(ids);
    } catch (error: any) {
      this.logger.error(`Unable to flush album deletes`, error?.stack);
      ids.forEach((id) => this.albumQueue.delete.add(id));
    }
  }

  if (this.assetQueue.upsert.size > 0) {
    const ids = [...this.assetQueue.upsert];
    this.assetQueue.upsert.clear();
    try {
      const items = await this.idsToAssets(ids);
      this.logger.debug(`Flushing ${items.length} asset upserts`);
      await this.searchRepository.importAssets(items, false);
    } catch (error: any) {
      this.logger.error(`Unable to flush asset upserts`, error?.stack);
      ids.forEach((id) => this.assetQueue.upsert.add(id));
    }
  }

  if (this.assetQueue.delete.size > 0) {
    const ids = [...this.assetQueue.delete];
    this.assetQueue.delete.clear();
    try {
      this.logger.debug(`Flushing ${ids.length} asset deletes`);
      await this.searchRepository.deleteAssets(ids);
    } catch (error: any) {
      this.logger.error(`Unable to flush asset deletes`, error?.stack);
      ids.forEach((id) => this.assetQueue.delete.add(id));
    }
  }
}
@ -166,4 +231,22 @@ export class SearchService {
throw new BadRequestException('Search is disabled'); throw new BadRequestException('Search is disabled');
} }
} }
/** Loads the albums with the given ids from the album repository and passes them through `patchAlbums`. */
private async idsToAlbums(ids: string[]): Promise<AlbumEntity[]> {
const entities = await this.albumRepository.getByIds(ids);
return this.patchAlbums(entities);
}
/**
 * Loads the assets with the given ids from the asset repository, drops the
 * ones whose `isVisible` is falsy, and passes the rest through `patchAssets`.
 */
private async idsToAssets(ids: string[]): Promise<AssetEntity[]> {
const entities = await this.assetRepository.getByIds(ids);
return this.patchAssets(entities.filter((entity) => entity.isVisible));
}
/** Applied to assets before they are imported into the search index; currently returns the input array unchanged. */
private patchAssets(assets: AssetEntity[]): AssetEntity[] {
return assets;
}
/**
 * Returns a shallow copy of each album with `assets` replaced by an empty array.
 * The copies are created with object spread, so the input entities are not mutated.
 */
private patchAlbums(albums: AlbumEntity[]): AlbumEntity[] {
return albums.map((entity) => ({ ...entity, assets: [] }));
}
} }

View file

@ -7,4 +7,6 @@ export interface MachineLearningInput {
export interface IMachineLearningRepository { export interface IMachineLearningRepository {
tagImage(input: MachineLearningInput): Promise<string[]>; tagImage(input: MachineLearningInput): Promise<string[]>;
detectObjects(input: MachineLearningInput): Promise<string[]>; detectObjects(input: MachineLearningInput): Promise<string[]>;
encodeImage(input: MachineLearningInput): Promise<number[]>;
encodeText(input: string): Promise<number[]>;
} }

View file

@ -1,5 +1,6 @@
import { AssetEntity } from '@app/infra/db/entities'; import { AssetEntity } from '@app/infra/db/entities';
import { newMachineLearningRepositoryMock, newSmartInfoRepositoryMock } from '../../test'; import { newJobRepositoryMock, newMachineLearningRepositoryMock, newSmartInfoRepositoryMock } from '../../test';
import { IJobRepository } from '../job';
import { IMachineLearningRepository } from './machine-learning.interface'; import { IMachineLearningRepository } from './machine-learning.interface';
import { ISmartInfoRepository } from './smart-info.repository'; import { ISmartInfoRepository } from './smart-info.repository';
import { SmartInfoService } from './smart-info.service'; import { SmartInfoService } from './smart-info.service';
@ -11,13 +12,15 @@ const asset = {
describe(SmartInfoService.name, () => { describe(SmartInfoService.name, () => {
let sut: SmartInfoService; let sut: SmartInfoService;
let jobMock: jest.Mocked<IJobRepository>;
let smartMock: jest.Mocked<ISmartInfoRepository>; let smartMock: jest.Mocked<ISmartInfoRepository>;
let machineMock: jest.Mocked<IMachineLearningRepository>; let machineMock: jest.Mocked<IMachineLearningRepository>;
beforeEach(async () => { beforeEach(async () => {
smartMock = newSmartInfoRepositoryMock(); smartMock = newSmartInfoRepositoryMock();
jobMock = newJobRepositoryMock();
machineMock = newMachineLearningRepositoryMock(); machineMock = newMachineLearningRepositoryMock();
sut = new SmartInfoService(smartMock, machineMock); sut = new SmartInfoService(jobMock, smartMock, machineMock);
}); });
it('should work', () => { it('should work', () => {

View file

@ -1,6 +1,6 @@
import { MACHINE_LEARNING_ENABLED } from '@app/common'; import { MACHINE_LEARNING_ENABLED } from '@app/common';
import { Inject, Injectable, Logger } from '@nestjs/common'; import { Inject, Injectable, Logger } from '@nestjs/common';
import { IAssetJob } from '../job'; import { IAssetJob, IJobRepository, JobName } from '../job';
import { IMachineLearningRepository } from './machine-learning.interface'; import { IMachineLearningRepository } from './machine-learning.interface';
import { ISmartInfoRepository } from './smart-info.repository'; import { ISmartInfoRepository } from './smart-info.repository';
@ -9,6 +9,7 @@ export class SmartInfoService {
private logger = new Logger(SmartInfoService.name); private logger = new Logger(SmartInfoService.name);
constructor( constructor(
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(ISmartInfoRepository) private repository: ISmartInfoRepository, @Inject(ISmartInfoRepository) private repository: ISmartInfoRepository,
@Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository, @Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository,
) {} ) {}
@ -24,6 +25,7 @@ export class SmartInfoService {
const tags = await this.machineLearning.tagImage({ thumbnailPath: asset.resizePath }); const tags = await this.machineLearning.tagImage({ thumbnailPath: asset.resizePath });
if (tags.length > 0) { if (tags.length > 0) {
await this.repository.upsert({ assetId: asset.id, tags }); await this.repository.upsert({ assetId: asset.id, tags });
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } });
} }
} catch (error: any) { } catch (error: any) {
this.logger.error(`Unable to run image tagging pipeline: ${asset.id}`, error?.stack); this.logger.error(`Unable to run image tagging pipeline: ${asset.id}`, error?.stack);
@ -41,9 +43,26 @@ export class SmartInfoService {
const objects = await this.machineLearning.detectObjects({ thumbnailPath: asset.resizePath }); const objects = await this.machineLearning.detectObjects({ thumbnailPath: asset.resizePath });
if (objects.length > 0) { if (objects.length > 0) {
await this.repository.upsert({ assetId: asset.id, objects }); await this.repository.upsert({ assetId: asset.id, objects });
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } });
} }
} catch (error: any) { } catch (error: any) {
this.logger.error(`Unable run object detection pipeline: ${asset.id}`, error?.stack); this.logger.error(`Unable run object detection pipeline: ${asset.id}`, error?.stack);
} }
} }
/**
 * Job handler: encodes the asset's resized image through the machine learning
 * repository, stores the result as the asset's `clipEmbedding`, and queues a
 * search index job for the asset.
 *
 * Does nothing when machine learning is disabled or the asset has no
 * `resizePath`. Failures are logged, not rethrown.
 */
async handleEncodeClip(data: IAssetJob) {
  if (!MACHINE_LEARNING_ENABLED) {
    return;
  }

  const asset = data.asset;
  if (!asset.resizePath) {
    return;
  }

  try {
    const embedding = await this.machineLearning.encodeImage({ thumbnailPath: asset.resizePath });
    await this.repository.upsert({ assetId: asset.id, clipEmbedding: embedding });
    // re-index the asset so the search index picks up the stored embedding
    await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } });
  } catch (error: any) {
    this.logger.error(`Unable run clip encoding pipeline: ${asset.id}`, error?.stack);
  }
}
} }

View file

@ -2,6 +2,7 @@ import { IAlbumRepository } from '../src';
export const newAlbumRepositoryMock = (): jest.Mocked<IAlbumRepository> => { export const newAlbumRepositoryMock = (): jest.Mocked<IAlbumRepository> => {
return { return {
getByIds: jest.fn(),
deleteAll: jest.fn(), deleteAll: jest.fn(),
getAll: jest.fn(), getAll: jest.fn(),
save: jest.fn(), save: jest.fn(),

View file

@ -2,6 +2,7 @@ import { IAssetRepository } from '../src';
export const newAssetRepositoryMock = (): jest.Mocked<IAssetRepository> => { export const newAssetRepositoryMock = (): jest.Mocked<IAssetRepository> => {
return { return {
getByIds: jest.fn(),
getAll: jest.fn(), getAll: jest.fn(),
deleteAll: jest.fn(), deleteAll: jest.fn(),
save: jest.fn(), save: jest.fn(),

View file

@ -15,6 +15,7 @@ import {
AuthUserDto, AuthUserDto,
ExifResponseDto, ExifResponseDto,
mapUser, mapUser,
SearchResult,
SharedLinkResponseDto, SharedLinkResponseDto,
} from '../src'; } from '../src';
@ -448,6 +449,7 @@ export const sharedLinkStub = {
tags: [], tags: [],
objects: ['a', 'b', 'c'], objects: ['a', 'b', 'c'],
asset: null as any, asset: null as any,
clipEmbedding: [0.12, 0.13, 0.14],
}, },
webpPath: '', webpPath: '',
encodedVideoPath: '', encodedVideoPath: '',
@ -550,3 +552,13 @@ export const sharedLinkResponseStub = {
// TODO - the constructor isn't used anywhere, so not test coverage // TODO - the constructor isn't used anywhere, so not test coverage
new ExifResponseDto(); new ExifResponseDto();
// Shared search result fixtures for tests.
export const searchStub = {
// frozen result object with no items and no facets, reporting page 1
emptyResults: Object.freeze<SearchResult<any>>({
total: 0,
count: 0,
page: 1,
items: [],
facets: [],
}),
};

View file

@ -13,3 +13,9 @@ export * from './storage.repository.mock';
export * from './system-config.repository.mock'; export * from './system-config.repository.mock';
export * from './user-token.repository.mock'; export * from './user-token.repository.mock';
export * from './user.repository.mock'; export * from './user.repository.mock';
/**
 * Awaits an already-resolved promise `steps` times in a row, giving other
 * pending promise callbacks a chance to run between iterations.
 *
 * @param steps number of consecutive awaits; values <= 0 perform none
 */
export async function asyncTick(steps: number) {
  let remaining = steps;
  while (remaining > 0) {
    await Promise.resolve();
    remaining--;
  }
}

View file

@ -4,5 +4,7 @@ export const newMachineLearningRepositoryMock = (): jest.Mocked<IMachineLearning
return { return {
tagImage: jest.fn(), tagImage: jest.fn(),
detectObjects: jest.fn(), detectObjects: jest.fn(),
encodeImage: jest.fn(),
encodeText: jest.fn(),
}; };
}; };

View file

@ -4,10 +4,13 @@ export const newSearchRepositoryMock = (): jest.Mocked<ISearchRepository> => {
return { return {
setup: jest.fn(), setup: jest.fn(),
checkMigrationStatus: jest.fn(), checkMigrationStatus: jest.fn(),
index: jest.fn(), importAssets: jest.fn(),
import: jest.fn(), importAlbums: jest.fn(),
search: jest.fn(), deleteAlbums: jest.fn(),
delete: jest.fn(), deleteAssets: jest.fn(),
searchAssets: jest.fn(),
searchAlbums: jest.fn(),
vectorSearch: jest.fn(),
explore: jest.fn(), explore: jest.fn(),
}; };
}; };

View file

@ -15,4 +15,14 @@ export class SmartInfoEntity {
@Column({ type: 'text', array: true, nullable: true }) @Column({ type: 'text', array: true, nullable: true })
objects!: string[] | null; objects!: string[] | null;
// Nullable numeric array column holding the asset's CLIP embedding.
@Column({
type: 'numeric',
array: true,
nullable: true,
// note: migration generator is broken for numeric[], but these _are_ set in the database
// precision: 20,
// scale: 19,
})
clipEmbedding!: number[] | null;
} }

View file

@ -0,0 +1,13 @@
import { MigrationInterface, QueryRunner } from 'typeorm';
/**
 * Adds the `clipEmbedding` column (`numeric(20,19)` array) to the `smart_info`
 * table; reverting the migration drops that column again.
 */
export class AddCLIPEncodeDataColumn1677971458822 implements MigrationInterface {
  name = 'AddCLIPEncodeDataColumn1677971458822';

  /** Applies the migration by adding the column. */
  public async up(queryRunner: QueryRunner): Promise<void> {
    const addColumn = `ALTER TABLE "smart_info" ADD "clipEmbedding" numeric(20,19) array`;
    await queryRunner.query(addColumn);
  }

  /** Reverts the migration by dropping the column. */
  public async down(queryRunner: QueryRunner): Promise<void> {
    const dropColumn = `ALTER TABLE "smart_info" DROP COLUMN "clipEmbedding"`;
    await queryRunner.query(dropColumn);
  }
}

View file

@ -1,19 +1,34 @@
import { IAlbumRepository } from '@app/domain'; import { IAlbumRepository } from '@app/domain';
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm'; import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm'; import { In, Repository } from 'typeorm';
import { AlbumEntity } from '../entities'; import { AlbumEntity } from '../entities';
@Injectable() @Injectable()
export class AlbumRepository implements IAlbumRepository { export class AlbumRepository implements IAlbumRepository {
constructor(@InjectRepository(AlbumEntity) private repository: Repository<AlbumEntity>) {} constructor(@InjectRepository(AlbumEntity) private repository: Repository<AlbumEntity>) {}
/** Finds the albums whose `id` is one of `ids`, loading the `owner` relation with each. */
getByIds(ids: string[]): Promise<AlbumEntity[]> {
return this.repository.find({
where: {
id: In(ids),
},
relations: {
owner: true,
},
});
}
async deleteAll(userId: string): Promise<void> { async deleteAll(userId: string): Promise<void> {
await this.repository.delete({ ownerId: userId }); await this.repository.delete({ ownerId: userId });
} }
getAll(): Promise<AlbumEntity[]> { getAll(): Promise<AlbumEntity[]> {
return this.repository.find(); return this.repository.find({
relations: {
owner: true,
},
});
} }
async save(album: Partial<AlbumEntity>) { async save(album: Partial<AlbumEntity>) {

View file

@ -1,13 +1,24 @@
import { AssetSearchOptions, IAssetRepository } from '@app/domain'; import { AssetSearchOptions, IAssetRepository } from '@app/domain';
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm'; import { InjectRepository } from '@nestjs/typeorm';
import { Not, Repository } from 'typeorm'; import { In, Not, Repository } from 'typeorm';
import { AssetEntity, AssetType } from '../entities'; import { AssetEntity, AssetType } from '../entities';
@Injectable() @Injectable()
export class AssetRepository implements IAssetRepository { export class AssetRepository implements IAssetRepository {
constructor(@InjectRepository(AssetEntity) private repository: Repository<AssetEntity>) {} constructor(@InjectRepository(AssetEntity) private repository: Repository<AssetEntity>) {}
/** Finds the assets whose `id` is one of `ids`, loading the `exifInfo`, `smartInfo` and `tags` relations with each. */
getByIds(ids: string[]): Promise<AssetEntity[]> {
return this.repository.find({
where: { id: In(ids) },
relations: {
exifInfo: true,
smartInfo: true,
tags: true,
},
});
}
async deleteAll(ownerId: string): Promise<void> { async deleteAll(ownerId: string): Promise<void> {
await this.repository.delete({ ownerId }); await this.repository.delete({ ownerId });
} }

View file

@ -41,6 +41,7 @@ export class JobRepository implements IJobRepository {
case JobName.OBJECT_DETECTION: case JobName.OBJECT_DETECTION:
case JobName.IMAGE_TAGGING: case JobName.IMAGE_TAGGING:
case JobName.ENCODE_CLIP:
await this.machineLearning.add(item.name, item.data); await this.machineLearning.add(item.name, item.data);
break; break;
@ -73,7 +74,7 @@ export class JobRepository implements IJobRepository {
case JobName.SEARCH_INDEX_ASSETS: case JobName.SEARCH_INDEX_ASSETS:
case JobName.SEARCH_INDEX_ALBUMS: case JobName.SEARCH_INDEX_ALBUMS:
await this.searchIndex.add(item.name); await this.searchIndex.add(item.name, {});
break; break;
case JobName.SEARCH_INDEX_ASSET: case JobName.SEARCH_INDEX_ASSET:

View file

@ -14,4 +14,12 @@ export class MachineLearningRepository implements IMachineLearningRepository {
detectObjects(input: MachineLearningInput): Promise<string[]> { detectObjects(input: MachineLearningInput): Promise<string[]> {
return client.post<string[]>('/object-detection/detect-object', input).then((res) => res.data); return client.post<string[]>('/object-detection/detect-object', input).then((res) => res.data);
} }
/** POSTs `input` to `/sentence-transformer/encode-image` and resolves with the response body. */
encodeImage(input: MachineLearningInput): Promise<number[]> {
  const request = client.post<number[]>('/sentence-transformer/encode-image', input);
  return request.then(({ data }) => data);
}
/** POSTs `{ text: input }` to `/sentence-transformer/encode-text` and resolves with the response body. */
encodeText(input: string): Promise<number[]> {
  const payload = { text: input };
  const request = client.post<number[]>('/sentence-transformer/encode-text', payload);
  return request.then(({ data }) => data);
}
} }

View file

@ -1,6 +1,6 @@
import { CollectionCreateSchema } from 'typesense/lib/Typesense/Collections'; import { CollectionCreateSchema } from 'typesense/lib/Typesense/Collections';
export const assetSchemaVersion = 2; export const assetSchemaVersion = 3;
export const assetSchema: CollectionCreateSchema = { export const assetSchema: CollectionCreateSchema = {
name: `assets-v${assetSchemaVersion}`, name: `assets-v${assetSchemaVersion}`,
fields: [ fields: [
@ -29,6 +29,7 @@ export const assetSchema: CollectionCreateSchema = {
// smart info // smart info
{ name: 'smartInfo.objects', type: 'string[]', facet: true, optional: true }, { name: 'smartInfo.objects', type: 'string[]', facet: true, optional: true },
{ name: 'smartInfo.tags', type: 'string[]', facet: true, optional: true }, { name: 'smartInfo.tags', type: 'string[]', facet: true, optional: true },
{ name: 'smartInfo.clipEmbedding', type: 'float[]', facet: false, optional: true, num_dim: 512 },
// computed // computed
{ name: 'geo', type: 'geopoint', facet: false, optional: true }, { name: 'geo', type: 'geopoint', facet: false, optional: true },

View file

@ -16,12 +16,7 @@ import { AlbumEntity, AssetEntity } from '../db';
import { albumSchema } from './schemas/album.schema'; import { albumSchema } from './schemas/album.schema';
import { assetSchema } from './schemas/asset.schema'; import { assetSchema } from './schemas/asset.schema';
interface CustomAssetEntity extends AssetEntity { function removeNil<T extends Dictionary<any>>(item: T): T {
geo?: [number, number];
motion?: boolean;
}
function removeNil<T extends Dictionary<any>>(item: T): Partial<T> {
_.forOwn(item, (value, key) => { _.forOwn(item, (value, key) => {
if (_.isNil(value) || (_.isObject(value) && !_.isDate(value) && _.isEmpty(removeNil(value)))) { if (_.isNil(value) || (_.isObject(value) && !_.isDate(value) && _.isEmpty(removeNil(value)))) {
delete item[key]; delete item[key];
@ -31,6 +26,11 @@ function removeNil<T extends Dictionary<any>>(item: T): Partial<T> {
return item; return item;
} }
interface CustomAssetEntity extends AssetEntity {
geo?: [number, number];
motion?: boolean;
}
const schemaMap: Record<SearchCollection, CollectionCreateSchema> = { const schemaMap: Record<SearchCollection, CollectionCreateSchema> = {
[SearchCollection.ASSETS]: assetSchema, [SearchCollection.ASSETS]: assetSchema,
[SearchCollection.ALBUMS]: albumSchema, [SearchCollection.ALBUMS]: albumSchema,
@ -38,24 +38,9 @@ const schemaMap: Record<SearchCollection, CollectionCreateSchema> = {
const schemas = Object.entries(schemaMap) as [SearchCollection, CollectionCreateSchema][]; const schemas = Object.entries(schemaMap) as [SearchCollection, CollectionCreateSchema][];
interface SearchUpdateQueue<T = any> {
upsert: T[];
delete: string[];
}
@Injectable() @Injectable()
export class TypesenseRepository implements ISearchRepository { export class TypesenseRepository implements ISearchRepository {
private logger = new Logger(TypesenseRepository.name); private logger = new Logger(TypesenseRepository.name);
private queue: Record<SearchCollection, SearchUpdateQueue> = {
[SearchCollection.ASSETS]: {
upsert: [],
delete: [],
},
[SearchCollection.ALBUMS]: {
upsert: [],
delete: [],
},
};
private _client: Client | null = null; private _client: Client | null = null;
private get client(): Client { private get client(): Client {
@ -83,8 +68,6 @@ export class TypesenseRepository implements ISearchRepository {
numRetries: 3, numRetries: 3,
connectionTimeoutSeconds: 10, connectionTimeoutSeconds: 10,
}); });
setInterval(() => this.flush(), 5_000);
} }
async setup(): Promise<void> { async setup(): Promise<void> {
@ -131,48 +114,27 @@ export class TypesenseRepository implements ISearchRepository {
return migrationMap; return migrationMap;
} }
async index(collection: SearchCollection, item: AssetEntity | AlbumEntity, immediate?: boolean): Promise<void> { async importAlbums(items: AlbumEntity[], done: boolean): Promise<void> {
const schema = schemaMap[collection]; await this.import(SearchCollection.ALBUMS, items, done);
if (collection === SearchCollection.ASSETS) {
item = this.patchAsset(item as AssetEntity);
}
if (immediate) {
await this.client.collections(schema.name).documents().upsert(item);
return;
}
this.queue[collection].upsert.push(item);
} }
async delete(collection: SearchCollection, id: string, immediate?: boolean): Promise<void> { async importAssets(items: AssetEntity[], done: boolean): Promise<void> {
const schema = schemaMap[collection]; await this.import(SearchCollection.ASSETS, items, done);
if (immediate) {
await this.client.collections(schema.name).documents().delete(id);
return;
}
this.queue[collection].delete.push(id);
} }
async import(collection: SearchCollection, items: AssetEntity[] | AlbumEntity[], done: boolean): Promise<void> { private async import(
collection: SearchCollection,
items: AlbumEntity[] | AssetEntity[],
done: boolean,
): Promise<void> {
try { try {
const schema = schemaMap[collection]; if (items.length > 0) {
const _items = items.map((item) => { await this.client.collections(schemaMap[collection].name).documents().import(this.patch(collection, items), {
if (collection === SearchCollection.ASSETS) { action: 'upsert',
item = this.patchAsset(item as AssetEntity); dirty_values: 'coerce_or_drop',
} });
// null values are invalid for typesense documents
return removeNil(item);
});
if (_items.length > 0) {
await this.client
.collections(schema.name)
.documents()
.import(_items, { action: 'upsert', dirty_values: 'coerce_or_drop' });
} }
if (done) { if (done) {
await this.updateAlias(collection); await this.updateAlias(collection);
} }
@ -234,71 +196,81 @@ export class TypesenseRepository implements ISearchRepository {
); );
} }
search(collection: SearchCollection.ASSETS, query: string, filter: SearchFilter): Promise<SearchResult<AssetEntity>>; async deleteAlbums(ids: string[]): Promise<void> {
search(collection: SearchCollection.ALBUMS, query: string, filter: SearchFilter): Promise<SearchResult<AlbumEntity>>; await this.delete(SearchCollection.ALBUMS, ids);
async search(collection: SearchCollection, query: string, filters: SearchFilter) {
const alias = await this.client.aliases(collection).retrieve();
const { userId } = filters;
const _filters = [`ownerId:${userId}`];
if (filters.id) {
_filters.push(`id:=${filters.id}`);
}
if (collection === SearchCollection.ASSETS) {
for (const item of schemaMap[collection].fields || []) {
let value = filters[item.name as keyof SearchFilter];
if (Array.isArray(value)) {
value = `[${value.join(',')}]`;
}
if (item.facet && value !== undefined) {
_filters.push(`${item.name}:${value}`);
}
}
this.logger.debug(`Searching query='${query}', filters='${JSON.stringify(_filters)}'`);
const results = await this.client
.collections<AssetEntity>(alias.collection_name)
.documents()
.search({
q: query,
query_by: [
'exifInfo.imageName',
'exifInfo.country',
'exifInfo.state',
'exifInfo.city',
'exifInfo.description',
'smartInfo.tags',
'smartInfo.objects',
].join(','),
filter_by: _filters.join(' && '),
per_page: 250,
sort_by: filters.recent ? 'createdAt:desc' : undefined,
facet_by: this.getFacetFieldNames(SearchCollection.ASSETS),
});
return this.asResponse(results);
}
if (collection === SearchCollection.ALBUMS) {
const results = await this.client
.collections<AlbumEntity>(alias.collection_name)
.documents()
.search({
q: query,
query_by: 'albumName',
filter_by: _filters.join(','),
});
return this.asResponse(results);
}
throw new Error(`Invalid collection: ${collection}`);
} }
private asResponse<T extends DocumentSchema>(results: SearchResponse<T>): SearchResult<T> { async deleteAssets(ids: string[]): Promise<void> {
await this.delete(SearchCollection.ASSETS, ids);
}
/**
 * Removes documents from the Typesense collection that backs `collection`,
 * using a single delete-by-filter request.
 *
 * @param collection Search collection whose schema name selects the target collection.
 * @param ids Ids of the documents to remove. An empty list is a no-op.
 */
async delete(collection: SearchCollection, ids: string[]): Promise<void> {
  // With no ids the filter would degenerate to `id: []`; there is nothing to
  // delete, so skip the round trip instead of sending that request.
  if (ids.length === 0) {
    return;
  }

  await this.client
    .collections(schemaMap[collection].name)
    .documents()
    .delete({ filter_by: `id: [${ids.join(',')}]` });
}
/**
 * Runs a text search over albums, matching `query` against `albumName`.
 *
 * @param query Search text passed to Typesense as `q`.
 * @param filters Filter values turned into a `filter_by` expression by `getAlbumFilters`;
 *   `filters.debug` is forwarded to `asResponse`.
 * @returns The search response converted by `asResponse`.
 */
async searchAlbums(query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>> {
  // Look the alias up first so the search targets the collection it currently names.
  const { collection_name: collectionName } = await this.client.aliases(SearchCollection.ALBUMS).retrieve();

  const documents = this.client.collections<AlbumEntity>(collectionName).documents();
  const response = await documents.search({
    q: query,
    query_by: 'albumName',
    filter_by: this.getAlbumFilters(filters),
  });

  return this.asResponse(response, filters.debug);
}
/**
 * Runs a text search over assets across the EXIF and smart-info text fields.
 *
 * @param query Search text passed to Typesense as `q`.
 * @param filters Filter values turned into a `filter_by` expression by `getAssetFilters`;
 *   `filters.recent` switches on `createdAt:desc` sorting and `filters.debug` is
 *   forwarded to `asResponse`.
 * @returns The search response converted by `asResponse`.
 */
async searchAssets(query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>> {
  // Look the alias up first so the search targets the collection it currently names.
  const { collection_name: collectionName } = await this.client.aliases(SearchCollection.ASSETS).retrieve();
  const documents = this.client.collections<AssetEntity>(collectionName).documents();

  // Fields searched for `query`, sent to Typesense as one comma-separated list.
  const queryBy = [
    'exifInfo.imageName',
    'exifInfo.country',
    'exifInfo.state',
    'exifInfo.city',
    'exifInfo.description',
    'smartInfo.tags',
    'smartInfo.objects',
  ].join(',');

  const response = await documents.search({
    q: query,
    query_by: queryBy,
    per_page: 250,
    facet_by: this.getFacetFieldNames(SearchCollection.ASSETS),
    filter_by: this.getAssetFilters(filters),
    // Leave sorting unset unless the caller asked for recent-first results.
    sort_by: filters.recent ? 'createdAt:desc' : undefined,
  });

  return this.asResponse(response, filters.debug);
}
/**
 * Searches assets by similarity to an embedding vector, using a one-entry
 * multi-search request against the `smartInfo.clipEmbedding` field.
 *
 * @param input Embedding to compare against; it is serialised inline into `vector_query`.
 * @param filters Filter values turned into a `filter_by` expression by `getAssetFilters`;
 *   `filters.debug` is forwarded to `asResponse`.
 * @returns The first (and only) multi-search result converted by `asResponse`.
 */
async vectorSearch(input: number[], filters: SearchFilter): Promise<SearchResult<AssetEntity>> {
  const { collection_name: collectionName } = await this.client.aliases(SearchCollection.ASSETS).retrieve();

  // NOTE(review): the `any` cast is kept from the original; presumably the client's
  // request typings do not include `vector_query` — confirm before tightening.
  const request = {
    collection: collectionName,
    q: '*',
    vector_query: `smartInfo.clipEmbedding:([${input.join(',')}], k:100)`,
    per_page: 250,
    facet_by: this.getFacetFieldNames(SearchCollection.ASSETS),
    filter_by: this.getAssetFilters(filters),
  } as any;

  const { results } = await this.client.multiSearch.perform({ searches: [request] });

  // Exactly one search was submitted, so only the first result is relevant.
  return this.asResponse(results[0] as SearchResponse<AssetEntity>, filters.debug);
}
private asResponse<T extends DocumentSchema>(results: SearchResponse<T>, debug?: boolean): SearchResult<T> {
return { return {
page: results.page, page: results.page,
total: results.found, total: results.found,
@ -308,51 +280,23 @@ export class TypesenseRepository implements ISearchRepository {
counts: facet.counts.map((item) => ({ count: item.count, value: item.value })), counts: facet.counts.map((item) => ({ count: item.count, value: item.value })),
fieldName: facet.field_name as string, fieldName: facet.field_name as string,
})), })),
}; debug: debug ? results : undefined,
} as SearchResult<T>;
} }
private async flush() { private handleError(error: any) {
for (const [collection, schema] of schemas) {
if (this.queue[collection].upsert.length > 0) {
try {
const items = this.queue[collection].upsert.map((item) => removeNil(item));
this.logger.debug(`Flushing ${items.length} ${collection} upserts to typesense`);
await this.client
.collections(schema.name)
.documents()
.import(items, { action: 'upsert', dirty_values: 'coerce_or_drop' });
this.queue[collection].upsert = [];
} catch (error) {
this.handleError(error);
}
}
if (this.queue[collection].delete.length > 0) {
try {
const items = this.queue[collection].delete;
this.logger.debug(`Flushing ${items.length} ${collection} deletes to typesense`);
await this.client
.collections(schema.name)
.documents()
.delete({ filter_by: `id: [${items.join(',')}]` });
this.queue[collection].delete = [];
} catch (error) {
this.handleError(error);
}
}
}
}
private handleError(error: any): never {
this.logger.error('Unable to index documents'); this.logger.error('Unable to index documents');
const results = error.importResults || []; const results = error.importResults || [];
for (const result of results) { for (const result of results) {
try { try {
result.document = JSON.parse(result.document); result.document = JSON.parse(result.document);
if (result.document?.smartInfo?.clipEmbedding) {
result.document.smartInfo.clipEmbedding = '<truncated>';
}
} catch {} } catch {}
} }
this.logger.verbose(JSON.stringify(results, null, 2)); this.logger.verbose(JSON.stringify(results, null, 2));
throw error;
} }
private async updateAlias(collection: SearchCollection) { private async updateAlias(collection: SearchCollection) {
@ -373,6 +317,18 @@ export class TypesenseRepository implements ISearchRepository {
} }
} }
/**
 * Prepares a batch of entities for import by running each one through the
 * patch routine that matches its collection.
 *
 * @param collection Decides the routine: `patchAsset` for assets, `patchAlbum` otherwise.
 * @param items Entities to patch.
 * @returns A new array holding the patched entities in input order.
 */
private patch(collection: SearchCollection, items: AssetEntity[] | AlbumEntity[]) {
  const isAssetCollection = collection === SearchCollection.ASSETS;

  return items.map((item) => {
    if (isAssetCollection) {
      return this.patchAsset(item as AssetEntity);
    }
    return this.patchAlbum(item as AlbumEntity);
  });
}
/**
 * Prepares an album for indexing by passing it through `removeNil`.
 *
 * @param album Album to clean.
 * @returns Whatever `removeNil` returns for `album`.
 */
private patchAlbum(album: AlbumEntity): AlbumEntity {
  const cleaned = removeNil(album);
  return cleaned;
}
private patchAsset(asset: AssetEntity): CustomAssetEntity { private patchAsset(asset: AssetEntity): CustomAssetEntity {
let custom = asset as CustomAssetEntity; let custom = asset as CustomAssetEntity;
@ -382,9 +338,7 @@ export class TypesenseRepository implements ISearchRepository {
custom = { ...custom, geo: [lat, lng] }; custom = { ...custom, geo: [lat, lng] };
} }
custom = { ...custom, motion: !!asset.livePhotoVideoId }; return removeNil({ ...custom, motion: !!asset.livePhotoVideoId });
return custom;
} }
private getFacetFieldNames(collection: SearchCollection) { private getFacetFieldNames(collection: SearchCollection) {
@ -393,4 +347,41 @@ export class TypesenseRepository implements ISearchRepository {
.map((field) => field.name) .map((field) => field.name)
.join(','); .join(',');
} }
/**
 * Builds the Typesense `filter_by` expression for album searches.
 *
 * The expression always restricts on `ownerId`, adds an exact `id` match when
 * `filters.id` is set, and adds one clause per facet field of the album schema
 * for which `filters` carries a value. Clauses are joined with ` && `.
 *
 * NOTE(review): filter values are interpolated as-is, without escaping —
 * confirm they are validated before they reach this method.
 *
 * @param filters Filter values supplied by the caller.
 * @returns The combined filter expression.
 */
private getAlbumFilters(filters: SearchFilter) {
  const clauses = [`ownerId:${filters.userId}`];

  if (filters.id) {
    clauses.push(`id:=${filters.id}`);
  }

  const fields = albumSchema.fields || [];
  fields.forEach((field) => {
    const raw = filters[field.name as keyof SearchFilter];

    // Only facet fields the caller actually provided a value for become clauses.
    if (!field.facet || raw === undefined) {
      return;
    }

    // Arrays are rendered in Typesense's `[a,b,c]` list syntax.
    const rendered = Array.isArray(raw) ? `[${raw.join(',')}]` : raw;
    clauses.push(`${field.name}:${rendered}`);
  });

  return clauses.join(' && ');
}
/**
 * Builds the Typesense `filter_by` expression for asset searches.
 *
 * The expression always restricts on `ownerId`, adds an exact `id` match when
 * `filters.id` is set, and adds one clause per facet field of the asset schema
 * for which `filters` carries a value. Clauses are joined with ` && `.
 *
 * NOTE(review): filter values are interpolated as-is, without escaping —
 * confirm they are validated before they reach this method.
 *
 * @param filters Filter values supplied by the caller.
 * @returns The combined filter expression.
 */
private getAssetFilters(filters: SearchFilter) {
  const { userId, id } = filters;
  const parts = [`ownerId:${userId}`];

  if (id) {
    parts.push(`id:=${id}`);
  }

  for (const field of assetSchema.fields || []) {
    // Non-facet fields never contribute a filter clause.
    if (!field.facet) {
      continue;
    }

    const supplied = filters[field.name as keyof SearchFilter];
    if (supplied === undefined) {
      continue;
    }

    // Arrays are rendered in Typesense's `[a,b,c]` list syntax.
    const rendered = Array.isArray(supplied) ? `[${supplied.join(',')}]` : supplied;
    parts.push(`${field.name}:${rendered}`);
  }

  return parts.join(' && ');
}
} }

View file

@ -48,7 +48,7 @@
"sanitize-filename": "^1.6.3", "sanitize-filename": "^1.6.3",
"sharp": "^0.28.0", "sharp": "^0.28.0",
"typeorm": "^0.3.11", "typeorm": "^0.3.11",
"typesense": "^1.5.2" "typesense": "^1.5.3"
}, },
"bin": { "bin": {
"immich": "bin/cli.sh" "immich": "bin/cli.sh"
@ -11137,9 +11137,9 @@
} }
}, },
"node_modules/typesense": { "node_modules/typesense": {
"version": "1.5.2", "version": "1.5.3",
"resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.2.tgz", "resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.3.tgz",
"integrity": "sha512-ysARFw+4z3AdSViOACqf7K9TXoP2wAXd5p5uSGTdXW14UYjcEzpV/S/EhMoiC6YdZyrnbDdNsxgWbf+AWJ9Udw==", "integrity": "sha512-eLHBP6AHex04tT+q/a7Uc+dFjIuoKTRpvlsNJwVTyedh4n0qnJxbfoLJBCxzhhZn5eITjEK0oWvVZ5byc3E+Ww==",
"dependencies": { "dependencies": {
"axios": "^0.26.0", "axios": "^0.26.0",
"loglevel": "^1.8.0" "loglevel": "^1.8.0"
@ -20023,9 +20023,9 @@
"devOptional": true "devOptional": true
}, },
"typesense": { "typesense": {
"version": "1.5.2", "version": "1.5.3",
"resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.2.tgz", "resolved": "https://registry.npmjs.org/typesense/-/typesense-1.5.3.tgz",
"integrity": "sha512-ysARFw+4z3AdSViOACqf7K9TXoP2wAXd5p5uSGTdXW14UYjcEzpV/S/EhMoiC6YdZyrnbDdNsxgWbf+AWJ9Udw==", "integrity": "sha512-eLHBP6AHex04tT+q/a7Uc+dFjIuoKTRpvlsNJwVTyedh4n0qnJxbfoLJBCxzhhZn5eITjEK0oWvVZ5byc3E+Ww==",
"requires": { "requires": {
"axios": "^0.26.0", "axios": "^0.26.0",
"loglevel": "^1.8.0" "loglevel": "^1.8.0"

View file

@ -78,7 +78,7 @@
"sanitize-filename": "^1.6.3", "sanitize-filename": "^1.6.3",
"sharp": "^0.28.0", "sharp": "^0.28.0",
"typeorm": "^0.3.11", "typeorm": "^0.3.11",
"typesense": "^1.5.2" "typesense": "^1.5.3"
}, },
"devDependencies": { "devDependencies": {
"@nestjs/cli": "^9.1.8", "@nestjs/cli": "^9.1.8",

View file

@ -6739,22 +6739,10 @@ export const SearchApiAxiosParamCreator = function (configuration?: Configuratio
}, },
/** /**
* *
* @param {string} [query]
* @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type]
* @param {boolean} [isFavorite]
* @param {string} [exifInfoCity]
* @param {string} [exifInfoState]
* @param {string} [exifInfoCountry]
* @param {string} [exifInfoMake]
* @param {string} [exifInfoModel]
* @param {Array<string>} [smartInfoObjects]
* @param {Array<string>} [smartInfoTags]
* @param {boolean} [recent]
* @param {boolean} [motion]
* @param {*} [options] Override http request option. * @param {*} [options] Override http request option.
* @throws {RequiredError} * @throws {RequiredError}
*/ */
search: async (query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options: AxiosRequestConfig = {}): Promise<RequestArgs> => { search: async (options: AxiosRequestConfig = {}): Promise<RequestArgs> => {
const localVarPath = `/search`; const localVarPath = `/search`;
// use dummy base URL string because the URL constructor only accepts absolute URLs. // use dummy base URL string because the URL constructor only accepts absolute URLs.
const localVarUrlObj = new URL(localVarPath, DUMMY_BASE_URL); const localVarUrlObj = new URL(localVarPath, DUMMY_BASE_URL);
@ -6773,54 +6761,6 @@ export const SearchApiAxiosParamCreator = function (configuration?: Configuratio
// authentication cookie required // authentication cookie required
if (query !== undefined) {
localVarQueryParameter['query'] = query;
}
if (type !== undefined) {
localVarQueryParameter['type'] = type;
}
if (isFavorite !== undefined) {
localVarQueryParameter['isFavorite'] = isFavorite;
}
if (exifInfoCity !== undefined) {
localVarQueryParameter['exifInfo.city'] = exifInfoCity;
}
if (exifInfoState !== undefined) {
localVarQueryParameter['exifInfo.state'] = exifInfoState;
}
if (exifInfoCountry !== undefined) {
localVarQueryParameter['exifInfo.country'] = exifInfoCountry;
}
if (exifInfoMake !== undefined) {
localVarQueryParameter['exifInfo.make'] = exifInfoMake;
}
if (exifInfoModel !== undefined) {
localVarQueryParameter['exifInfo.model'] = exifInfoModel;
}
if (smartInfoObjects) {
localVarQueryParameter['smartInfo.objects'] = smartInfoObjects;
}
if (smartInfoTags) {
localVarQueryParameter['smartInfo.tags'] = smartInfoTags;
}
if (recent !== undefined) {
localVarQueryParameter['recent'] = recent;
}
if (motion !== undefined) {
localVarQueryParameter['motion'] = motion;
}
setSearchParams(localVarUrlObj, localVarQueryParameter); setSearchParams(localVarUrlObj, localVarQueryParameter);
@ -6862,23 +6802,11 @@ export const SearchApiFp = function(configuration?: Configuration) {
}, },
/** /**
* *
* @param {string} [query]
* @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type]
* @param {boolean} [isFavorite]
* @param {string} [exifInfoCity]
* @param {string} [exifInfoState]
* @param {string} [exifInfoCountry]
* @param {string} [exifInfoMake]
* @param {string} [exifInfoModel]
* @param {Array<string>} [smartInfoObjects]
* @param {Array<string>} [smartInfoTags]
* @param {boolean} [recent]
* @param {boolean} [motion]
* @param {*} [options] Override http request option. * @param {*} [options] Override http request option.
* @throws {RequiredError} * @throws {RequiredError}
*/ */
async search(query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options?: AxiosRequestConfig): Promise<(axios?: AxiosInstance, basePath?: string) => AxiosPromise<SearchResponseDto>> { async search(options?: AxiosRequestConfig): Promise<(axios?: AxiosInstance, basePath?: string) => AxiosPromise<SearchResponseDto>> {
const localVarAxiosArgs = await localVarAxiosParamCreator.search(query, type, isFavorite, exifInfoCity, exifInfoState, exifInfoCountry, exifInfoMake, exifInfoModel, smartInfoObjects, smartInfoTags, recent, motion, options); const localVarAxiosArgs = await localVarAxiosParamCreator.search(options);
return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration); return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration);
}, },
} }
@ -6909,23 +6837,11 @@ export const SearchApiFactory = function (configuration?: Configuration, basePat
}, },
/** /**
* *
* @param {string} [query]
* @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type]
* @param {boolean} [isFavorite]
* @param {string} [exifInfoCity]
* @param {string} [exifInfoState]
* @param {string} [exifInfoCountry]
* @param {string} [exifInfoMake]
* @param {string} [exifInfoModel]
* @param {Array<string>} [smartInfoObjects]
* @param {Array<string>} [smartInfoTags]
* @param {boolean} [recent]
* @param {boolean} [motion]
* @param {*} [options] Override http request option. * @param {*} [options] Override http request option.
* @throws {RequiredError} * @throws {RequiredError}
*/ */
search(query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options?: any): AxiosPromise<SearchResponseDto> { search(options?: any): AxiosPromise<SearchResponseDto> {
return localVarFp.search(query, type, isFavorite, exifInfoCity, exifInfoState, exifInfoCountry, exifInfoMake, exifInfoModel, smartInfoObjects, smartInfoTags, recent, motion, options).then((request) => request(axios, basePath)); return localVarFp.search(options).then((request) => request(axios, basePath));
}, },
}; };
}; };
@ -6959,24 +6875,12 @@ export class SearchApi extends BaseAPI {
/** /**
* *
* @param {string} [query]
* @param {'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER'} [type]
* @param {boolean} [isFavorite]
* @param {string} [exifInfoCity]
* @param {string} [exifInfoState]
* @param {string} [exifInfoCountry]
* @param {string} [exifInfoMake]
* @param {string} [exifInfoModel]
* @param {Array<string>} [smartInfoObjects]
* @param {Array<string>} [smartInfoTags]
* @param {boolean} [recent]
* @param {boolean} [motion]
* @param {*} [options] Override http request option. * @param {*} [options] Override http request option.
* @throws {RequiredError} * @throws {RequiredError}
* @memberof SearchApi * @memberof SearchApi
*/ */
public search(query?: string, type?: 'IMAGE' | 'VIDEO' | 'AUDIO' | 'OTHER', isFavorite?: boolean, exifInfoCity?: string, exifInfoState?: string, exifInfoCountry?: string, exifInfoMake?: string, exifInfoModel?: string, smartInfoObjects?: Array<string>, smartInfoTags?: Array<string>, recent?: boolean, motion?: boolean, options?: AxiosRequestConfig) { public search(options?: AxiosRequestConfig) {
return SearchApiFp(this.configuration).search(query, type, isFavorite, exifInfoCity, exifInfoState, exifInfoCountry, exifInfoMake, exifInfoModel, smartInfoObjects, smartInfoTags, recent, motion, options).then((request) => request(this.axios, this.basePath)); return SearchApiFp(this.configuration).search(options).then((request) => request(this.axios, this.basePath));
} }
} }

View file

@ -15,7 +15,8 @@
function onSearch() { function onSearch() {
const params = new URLSearchParams({ const params = new URLSearchParams({
q: value q: value,
clip: 'true'
}); });
goto(`${AppRoute.SEARCH}?${params}`, { replaceState: replaceHistoryState }); goto(`${AppRoute.SEARCH}?${params}`, { replaceState: replaceHistoryState });

View file

@ -7,22 +7,9 @@ export const load = (async ({ locals, parent, url }) => {
throw redirect(302, '/auth/login'); throw redirect(302, '/auth/login');
} }
const term = url.searchParams.get('q') || undefined; const term = url.searchParams.get('q') || url.searchParams.get('query') || undefined;
const { data: results } = await locals.api.searchApi.search(
term, const { data: results } = await locals.api.searchApi.search({ params: url.searchParams });
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
{ params: url.searchParams }
);
return { return {
user, user,