mirror of
https://github.com/immich-app/immich.git
synced 2025-03-29 03:09:38 +01:00
358 lines
12 KiB
TypeScript
import { Injectable } from '@nestjs/common';
|
|
import { Kysely, OrderByDirectionExpression, sql } from 'kysely';
|
|
import { InjectKysely } from 'nestjs-kysely';
|
|
import { randomUUID } from 'node:crypto';
|
|
import { DB } from 'src/db';
|
|
import { DummyValue, GenerateSql } from 'src/decorators';
|
|
import { AssetEntity, searchAssetBuilder } from 'src/entities/asset.entity';
|
|
import { GeodataPlacesEntity } from 'src/entities/geodata-places.entity';
|
|
import { AssetType } from 'src/enum';
|
|
import {
|
|
AssetDuplicateSearch,
|
|
AssetSearchOptions,
|
|
FaceEmbeddingSearch,
|
|
GetCameraMakesOptions,
|
|
GetCameraModelsOptions,
|
|
GetCitiesOptions,
|
|
GetStatesOptions,
|
|
ISearchRepository,
|
|
SearchPaginationOptions,
|
|
SmartSearchOptions,
|
|
} from 'src/interfaces/search.interface';
|
|
import { LoggingRepository } from 'src/repositories/logging.repository';
|
|
import { anyUuid, asUuid } from 'src/utils/database';
|
|
import { Paginated } from 'src/utils/pagination';
|
|
import { isValidInteger } from 'src/validation';
|
|
|
|
/**
 * Kysely-backed implementation of {@link ISearchRepository}: metadata, random and embedding
 * ("smart") asset searches, face-embedding search, duplicate detection, place lookup,
 * maintenance of the `smart_search` embedding table, and distinct EXIF value listings.
 */
@Injectable()
export class SearchRepository implements ISearchRepository {
  constructor(
    private logger: LoggingRepository,
    @InjectKysely() private db: Kysely<DB>,
  ) {
    this.logger.setContext(SearchRepository.name);
  }

  /**
   * Paginated search over asset metadata.
   *
   * Rows are ordered by `assets.fileCreatedAt`, descending unless `options.orderDirection` says
   * otherwise. One row beyond the page size is fetched so `hasNextPage` can be derived from the
   * result length; that look-ahead row is removed before returning.
   */
  @GenerateSql({
    params: [
      { page: 1, size: 100 },
      {
        takenAfter: DummyValue.DATE,
        lensModel: DummyValue.STRING,
        withStacked: true,
        isFavorite: true,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  async searchMetadata(pagination: SearchPaginationOptions, options: AssetSearchOptions): Paginated<AssetEntity> {
    // `||` rather than `??` so an empty string also falls back to 'desc'
    const orderDirection = (options.orderDirection?.toLowerCase() || 'desc') as OrderByDirectionExpression;
    const items = await searchAssetBuilder(this.db, options)
      .orderBy('assets.fileCreatedAt', orderDirection)
      .limit(pagination.size + 1)
      .offset((pagination.page - 1) * pagination.size)
      .execute();
    const hasNextPage = items.length > pagination.size;
    // drop the look-ahead row used only to compute `hasNextPage`
    items.splice(pagination.size);
    return { items: items as any as AssetEntity[], hasNextPage };
  }

  /**
   * Returns up to `size` assets matching `options`, starting at a random UUID pivot.
   *
   * Assets whose id sorts just after the pivot are taken first (ascending). If fewer than `size`
   * exist, the result wraps around to the lowest ids. The `>` branch must come first. With the `<`
   * branch first, the result would almost always be the globally smallest ids, whatever the pivot.
   * The result is a contiguous run of ids starting at a random point, not an independent sample.
   */
  @GenerateSql({
    params: [
      100,
      {
        takenAfter: DummyValue.DATE,
        lensModel: DummyValue.STRING,
        withStacked: true,
        isFavorite: true,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  async searchRandom(size: number, options: AssetSearchOptions): Promise<AssetEntity[]> {
    const pivot = randomUUID();
    const builder = searchAssetBuilder(this.db, options);
    const afterPivot = builder.where('assets.id', '>', pivot).orderBy('assets.id').limit(size);
    const wrapAround = builder.where('assets.id', '<', pivot).orderBy('assets.id').limit(size);
    // NOTE(review): there is no outer ORDER BY, so this relies on Postgres emitting the
    // UNION ALL branches in order.
    const { rows } = await sql`${afterPivot} union all ${wrapAround} limit ${size}`.execute(this.db);
    return rows as any as AssetEntity[];
  }

  /**
   * Paginated embedding similarity search.
   *
   * Returns assets matching `options`, ordered by the `<=>` distance between their `smart_search`
   * embedding and `options.embedding`, smallest distance first. One look-ahead row is fetched to
   * compute `hasNextPage`.
   *
   * @throws Error when `pagination.size` is not an integer in [1, 1000].
   */
  @GenerateSql({
    params: [
      { page: 1, size: 200 },
      {
        takenAfter: DummyValue.DATE,
        embedding: DummyValue.VECTOR,
        lensModel: DummyValue.STRING,
        withStacked: true,
        isFavorite: true,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions): Paginated<AssetEntity> {
    if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
      throw new Error(`Invalid value for 'size': ${pagination.size}`);
    }

    const items = (await searchAssetBuilder(this.db, options)
      .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
      .orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
      .limit(pagination.size + 1)
      .offset((pagination.page - 1) * pagination.size)
      .execute()) as any as AssetEntity[];

    const hasNextPage = items.length > pagination.size;
    // drop the look-ahead row used only to compute `hasNextPage`
    items.splice(pagination.size);
    return { items, hasNextPage };
  }

  /**
   * Finds duplicate candidates for `assetId`.
   *
   * Takes the 64 nearest assets by embedding distance (`<=>`). Candidates must be of the same
   * `type`, owned by one of `userIds`, visible, not deleted, and must not be the asset itself.
   * Only candidates whose distance is at most `maxDistance` are kept.
   *
   * @returns rows of `{ assetId, duplicateId, distance }`.
   */
  @GenerateSql({
    params: [
      {
        assetId: DummyValue.UUID,
        embedding: DummyValue.VECTOR,
        maxDistance: 0.6,
        type: AssetType.IMAGE,
        userIds: [DummyValue.UUID],
      },
    ],
  })
  searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
    return this.db
      .with('cte', (qb) =>
        qb
          .selectFrom('assets')
          .select([
            'assets.id as assetId',
            'assets.duplicateId',
            sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
          ])
          .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
          .where('assets.ownerId', '=', anyUuid(userIds))
          .where('assets.deletedAt', 'is', null)
          .where('assets.isVisible', '=', true)
          .where('assets.type', '=', type)
          .where('assets.id', '!=', asUuid(assetId))
          .orderBy(sql`smart_search.embedding <=> ${embedding}`)
          // limit inside the CTE so the ordering by distance can stop early; the distance
          // threshold is applied to this bounded candidate set afterwards
          .limit(64),
      )
      .selectFrom('cte')
      .selectAll()
      .where('cte.distance', '<=', maxDistance as number)
      .execute();
  }

  /**
   * Returns up to `numResults` faces nearest to `embedding` whose distance is at most `maxDistance`.
   *
   * Faces must belong to non-deleted assets owned by one of `userIds`. When `hasPerson` is truthy,
   * only faces already assigned to a person are considered.
   *
   * @throws Error when `numResults` is not an integer in [1, 1000].
   * @returns rows of `{ id, personId, distance }`.
   */
  @GenerateSql({
    params: [
      {
        userIds: [DummyValue.UUID],
        embedding: DummyValue.VECTOR,
        numResults: 10,
        maxDistance: 0.6,
      },
    ],
  })
  searchFaces({ userIds, embedding, numResults, maxDistance, hasPerson }: FaceEmbeddingSearch) {
    if (!isValidInteger(numResults, { min: 1, max: 1000 })) {
      throw new Error(`Invalid value for 'numResults': ${numResults}`);
    }

    return this.db
      .with('cte', (qb) =>
        qb
          .selectFrom('asset_faces')
          .select([
            'asset_faces.id',
            'asset_faces.personId',
            sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
          ])
          .innerJoin('assets', 'assets.id', 'asset_faces.assetId')
          .innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
          .where('assets.ownerId', '=', anyUuid(userIds))
          .where('assets.deletedAt', 'is', null)
          .$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
          .orderBy(sql`face_search.embedding <=> ${embedding}`)
          .limit(numResults),
      )
      .selectFrom('cte')
      .selectAll()
      .where('cte.distance', '<=', maxDistance)
      .execute();
  }

  /**
   * Fuzzy place-name lookup over `geodata_places`.
   *
   * A row matches when `name`, `admin2Name`, `admin1Name` or `alternateNames` matches `placeName`
   * under `%>>`, with both sides passed through `f_unaccent`. Matches are ranked by the sum of the
   * per-column `<->>>` distances, where a null distance counts as 0.1. At most 20 rows are returned.
   */
  @GenerateSql({ params: [DummyValue.STRING] })
  searchPlaces(placeName: string): Promise<GeodataPlacesEntity[]> {
    return this.db
      .selectFrom('geodata_places')
      .selectAll()
      .where(
        () =>
          // kysely doesn't support trigram %>> or <->>> operators
          sql`
            f_unaccent(name) %>> f_unaccent(${placeName}) or
            f_unaccent("admin2Name") %>> f_unaccent(${placeName}) or
            f_unaccent("admin1Name") %>> f_unaccent(${placeName}) or
            f_unaccent("alternateNames") %>> f_unaccent(${placeName})
          `,
      )
      .orderBy(
        sql`
          coalesce(f_unaccent(name) <->>> f_unaccent(${placeName}), 0.1) +
          coalesce(f_unaccent("admin2Name") <->>> f_unaccent(${placeName}), 0.1) +
          coalesce(f_unaccent("admin1Name") <->>> f_unaccent(${placeName}), 0.1) +
          coalesce(f_unaccent("alternateNames") <->>> f_unaccent(${placeName}), 0.1)
        `,
      )
      .limit(20)
      .execute() as Promise<GeodataPlacesEntity[]>;
  }

  /**
   * Returns one asset per distinct, non-null EXIF city, with its EXIF row attached as `exifInfo`,
   * ordered by city.
   *
   * Assets must be visible, unarchived, non-deleted images owned by one of `userIds`. A recursive
   * CTE walks the cities in ascending order. Each step picks the first city greater than the
   * previous one, so the cities are visited one by one without a `DISTINCT` over all rows.
   */
  @GenerateSql({ params: [[DummyValue.UUID]] })
  getAssetsByCity(userIds: string[]): Promise<AssetEntity[]> {
    return this.db
      .withRecursive('cte', (qb) => {
        const base = qb
          .selectFrom('exif')
          .select(['city', 'assetId'])
          .innerJoin('assets', 'assets.id', 'exif.assetId')
          .where('assets.ownerId', '=', anyUuid(userIds))
          .where('assets.isVisible', '=', true)
          .where('assets.isArchived', '=', false)
          .where('assets.type', '=', AssetType.IMAGE)
          .where('assets.deletedAt', 'is', null)
          // the recursive step (`city > cte.city`) never yields nulls; keep the seed consistent
          .where('exif.city', 'is not', null)
          .orderBy('city')
          .limit(1);

        const recursive = qb
          .selectFrom('cte')
          .select(['l.city', 'l.assetId'])
          .innerJoinLateral(
            (qb) =>
              qb
                .selectFrom('exif')
                .select(['city', 'assetId'])
                .innerJoin('assets', 'assets.id', 'exif.assetId')
                .where('assets.ownerId', '=', anyUuid(userIds))
                .where('assets.isVisible', '=', true)
                .where('assets.isArchived', '=', false)
                .where('assets.type', '=', AssetType.IMAGE)
                .where('assets.deletedAt', 'is', null)
                .whereRef('exif.city', '>', 'cte.city')
                .orderBy('city')
                .limit(1)
                .as('l'),
            (join) => join.onTrue(),
          );

        return sql<{ city: string; assetId: string }>`(${base} union all ${recursive})`;
      })
      .selectFrom('assets')
      .innerJoin('exif', 'assets.id', 'exif.assetId')
      .innerJoin('cte', 'assets.id', 'cte.assetId')
      .selectAll('assets')
      .select((eb) => eb.fn('to_jsonb', [eb.table('exif')]).as('exifInfo'))
      .orderBy('exif.city')
      .execute() as any as Promise<AssetEntity[]>;
  }

  /** Inserts the embedding for `assetId` into `smart_search`, or replaces it if a row already exists. */
  async upsert(assetId: string, embedding: string): Promise<void> {
    await this.db
      .insertInto('smart_search')
      .values({ assetId: asUuid(assetId), embedding } as any)
      .onConflict((oc) => oc.column('assetId').doUpdateSet({ embedding } as any))
      .execute();
  }

  /**
   * Reads the declared dimension of `smart_search.embedding` (its `atttypmod`) from the Postgres catalog.
   *
   * @throws Error when the column cannot be found, or its dimension is not an integer in [1, 65536].
   */
  async getDimensionSize(): Promise<number> {
    const { rows } = await sql<{ dimsize: number }>`
      select atttypmod as dimsize
      from pg_attribute f
        join pg_class c ON c.oid = f.attrelid
      where c.relkind = 'r'::char
        and f.attnum > 0
        and c.relname = 'smart_search'
        and f.attname = 'embedding'
    `.execute(this.db);

    // guard the missing-row case so callers get the descriptive error below, not a TypeError
    const dimSize = rows[0]?.dimsize;
    if (dimSize === undefined || !isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
      throw new Error(`Could not retrieve CLIP dimension size`);
    }
    return dimSize;
  }

  /**
   * Changes `smart_search.embedding` to `vector(dimSize)`.
   *
   * Runs in one transaction. It first truncates `smart_search`, so all stored embeddings are
   * discarded. It then alters the column type and rebuilds `clip_index`.
   *
   * @throws Error when `dimSize` is not an integer in [1, 65536].
   */
  setDimensionSize(dimSize: number): Promise<void> {
    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
      throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
    }

    return this.db.transaction().execute(async (trx) => {
      await sql`truncate ${sql.table('smart_search')}`.execute(trx);
      await trx.schema
        .alterTable('smart_search')
        // safe to inline: dimSize was validated as a bounded integer above
        .alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
        .execute();
      await sql`reindex index clip_index`.execute(trx);
    });
  }

  /** Removes every row from `smart_search`. */
  async deleteAllSearchEmbeddings(): Promise<void> {
    await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
  }

  /** Distinct, non-null EXIF countries of the users' visible, non-deleted assets. */
  @GenerateSql({ params: [[DummyValue.UUID]] })
  async getCountries(userIds: string[]): Promise<string[]> {
    const res = await this.getExifField('country', userIds).execute();
    return res.map((row) => row.country!);
  }

  /** Distinct, non-null EXIF states, optionally restricted to `country`. */
  @GenerateSql({ params: [[DummyValue.UUID], { country: DummyValue.STRING }] })
  async getStates(userIds: string[], { country }: GetStatesOptions): Promise<string[]> {
    const res = await this.getExifField('state', userIds)
      .$if(!!country, (qb) => qb.where('country', '=', country!))
      .execute();

    return res.map((row) => row.state!);
  }

  /** Distinct, non-null EXIF cities, optionally restricted to `country` and/or `state`. */
  @GenerateSql({ params: [[DummyValue.UUID], { country: DummyValue.STRING, state: DummyValue.STRING }] })
  async getCities(userIds: string[], { country, state }: GetCitiesOptions): Promise<string[]> {
    const res = await this.getExifField('city', userIds)
      .$if(!!country, (qb) => qb.where('country', '=', country!))
      .$if(!!state, (qb) => qb.where('state', '=', state!))
      .execute();

    return res.map((row) => row.city!);
  }

  /** Distinct, non-null camera makes, optionally restricted to camera `model`. */
  @GenerateSql({ params: [[DummyValue.UUID], { model: DummyValue.STRING }] })
  async getCameraMakes(userIds: string[], { model }: GetCameraMakesOptions): Promise<string[]> {
    const res = await this.getExifField('make', userIds)
      .$if(!!model, (qb) => qb.where('model', '=', model!))
      .execute();

    return res.map((row) => row.make!);
  }

  /** Distinct, non-null camera models, optionally restricted to camera `make`. */
  @GenerateSql({ params: [[DummyValue.UUID], { make: DummyValue.STRING }] })
  async getCameraModels(userIds: string[], { make }: GetCameraModelsOptions): Promise<string[]> {
    const res = await this.getExifField('model', userIds)
      .$if(!!make, (qb) => qb.where('make', '=', make!))
      .execute();

    return res.map((row) => row.model!);
  }

  /**
   * Builds a query for the distinct, non-null values of one EXIF `field`.
   *
   * Only visible, non-deleted assets owned by one of `userIds` are considered. The query is
   * returned unexecuted so callers can add filters.
   */
  private getExifField<K extends 'city' | 'state' | 'country' | 'make' | 'model'>(field: K, userIds: string[]) {
    return this.db
      .selectFrom('exif')
      .select(field)
      .distinctOn(field)
      .innerJoin('assets', 'assets.id', 'exif.assetId')
      .where('ownerId', '=', anyUuid(userIds))
      .where('isVisible', '=', true)
      .where('deletedAt', 'is', null)
      .where(field, 'is not', null);
  }
}
|